xref: /aosp_15_r20/external/armnn/src/backends/neon/test/NeonFallbackTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/test/mockBackend/MockImportBackend.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <GraphUtils.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("NeonFallback")
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackImportToCpuAcc")
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend objectN
20*89c4ff92SAndroid Build Coastguard Worker     MockImportBackendInitialiser initialiser; // Register the Mock Backend
21*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockImportBackendId());
22*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds();
25*89c4ff92SAndroid Build Coastguard Worker     if (backendIds.find("MockRef") == backendIds.end())
26*89c4ff92SAndroid Build Coastguard Worker     {
27*89c4ff92SAndroid Build Coastguard Worker         std::string message = "Cannot load MockRef";
28*89c4ff92SAndroid Build Coastguard Worker         FAIL(message);
29*89c4ff92SAndroid Build Coastguard Worker     }
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run and allow fallback to CpuRef.
32*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
33*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
36*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
39*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
40*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
41*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
42*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
43*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
46*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
47*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
48*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
49*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(output->GetInputSlot(0));
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
54*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
55*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
56*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
57*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
60*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc };
61*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
62*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
63*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
64*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
69*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
70*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
71*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
72*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
73*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
74*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output");
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
77*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
78*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
79*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
80*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
81*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
82*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
85*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
86*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
87*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
88*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
91*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f
94*89c4ff92SAndroid Build Coastguard Worker     };
95*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
96*89c4ff92SAndroid Build Coastguard Worker     {
97*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
98*89c4ff92SAndroid Build Coastguard Worker     };
99*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
102*89c4ff92SAndroid Build Coastguard Worker     };
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(12);
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
107*89c4ff92SAndroid Build Coastguard Worker     {
108*89c4ff92SAndroid Build Coastguard Worker         11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f
109*89c4ff92SAndroid Build Coastguard Worker     };
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
112*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
113*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2);
114*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
115*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
116*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo2.SetConstant(true);
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
119*89c4ff92SAndroid Build Coastguard Worker     {
120*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
121*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) },
122*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) }
123*89c4ff92SAndroid Build Coastguard Worker     };
124*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
125*89c4ff92SAndroid Build Coastguard Worker     {
126*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
127*89c4ff92SAndroid Build Coastguard Worker     };
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
132*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
135*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
136*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
137*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
138*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker     // Contains ImportMemGeneric
141*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("ImportMemGeneric");
142*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker     // Contains SyncMemGeneric
145*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("SyncMemGeneric");
146*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     // Does not contain CopyMemGeneric
149*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
150*89c4ff92SAndroid Build Coastguard Worker     CHECK(found == std::string::npos);
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
153*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemImport));
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
156*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackPaddingCopyToCpuAcc")
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
164*89c4ff92SAndroid Build Coastguard Worker     MockImportBackendInitialiser initialiser; // Register the Mock Backend
165*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockImportBackendId());
166*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker     BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds();
169*89c4ff92SAndroid Build Coastguard Worker     if (backendIds.find("MockRef") == backendIds.end())
170*89c4ff92SAndroid Build Coastguard Worker     {
171*89c4ff92SAndroid Build Coastguard Worker         std::string message = "Cannot load MockRef";
172*89c4ff92SAndroid Build Coastguard Worker         FAIL(message);
173*89c4ff92SAndroid Build Coastguard Worker     }
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run and allow fallback to CpuRef.
176*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
177*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
180*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
185*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
186*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
187*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling");
188*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
189*89c4ff92SAndroid Build Coastguard Worker 
190*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
191*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
192*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
193*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
196*89c4ff92SAndroid Build Coastguard Worker     TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32);
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
199*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
200*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
201*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo);
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
204*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc };
205*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
206*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
207*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
208*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
213*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
214*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "add");
215*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "[ add (0) -> pooling (0) ]");
216*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "pooling");
217*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "output");
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
220*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
221*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
222*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
223*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
224*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
227*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
228*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
229*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
232*89c4ff92SAndroid Build Coastguard Worker 
233*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
234*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
235*89c4ff92SAndroid Build Coastguard Worker     {
236*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f
237*89c4ff92SAndroid Build Coastguard Worker     };
238*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
239*89c4ff92SAndroid Build Coastguard Worker     {
240*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
241*89c4ff92SAndroid Build Coastguard Worker     };
242*89c4ff92SAndroid Build Coastguard Worker 
243*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(2);
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
246*89c4ff92SAndroid Build Coastguard Worker     {
247*89c4ff92SAndroid Build Coastguard Worker         6.0f, 12.0f
248*89c4ff92SAndroid Build Coastguard Worker     };
249*89c4ff92SAndroid Build Coastguard Worker 
250*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
251*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
252*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
253*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
254*89c4ff92SAndroid Build Coastguard Worker 
255*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
256*89c4ff92SAndroid Build Coastguard Worker     {
257*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
258*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }
259*89c4ff92SAndroid Build Coastguard Worker     };
260*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
261*89c4ff92SAndroid Build Coastguard Worker     {
262*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
263*89c4ff92SAndroid Build Coastguard Worker     };
264*89c4ff92SAndroid Build Coastguard Worker 
265*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
268*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
269*89c4ff92SAndroid Build Coastguard Worker 
270*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
271*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
272*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
273*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
274*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
275*89c4ff92SAndroid Build Coastguard Worker 
276*89c4ff92SAndroid Build Coastguard Worker     // Contains CopyMemGeneric between the backends
277*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("CopyMemGeneric");
278*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker     // Contains SyncMemGeneric for the output
281*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("SyncMemGeneric");
282*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
283*89c4ff92SAndroid Build Coastguard Worker 
284*89c4ff92SAndroid Build Coastguard Worker     // Does not contain ImportMemGeneric
285*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("ImportMemGeneric");
286*89c4ff92SAndroid Build Coastguard Worker     CHECK(found == std::string::npos);
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
289*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer3->GetType() == LayerType::MemCopy));
290*89c4ff92SAndroid Build Coastguard Worker 
291*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
292*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker 
295*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackImportFromCpuAcc")
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
298*89c4ff92SAndroid Build Coastguard Worker 
299*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
300*89c4ff92SAndroid Build Coastguard Worker     MockImportBackendInitialiser initialiser; // Register the Mock Backend
301*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockImportBackendId());
302*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker     BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds();
305*89c4ff92SAndroid Build Coastguard Worker     if (backendIds.find("MockRef") == backendIds.end())
306*89c4ff92SAndroid Build Coastguard Worker     {
307*89c4ff92SAndroid Build Coastguard Worker         std::string message = "Cannot load MockRef";
308*89c4ff92SAndroid Build Coastguard Worker         FAIL(message);
309*89c4ff92SAndroid Build Coastguard Worker     }
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run and allow fallback to CpuRef.
312*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
313*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
314*89c4ff92SAndroid Build Coastguard Worker 
315*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
316*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
319*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
320*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
321*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
322*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
323*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
324*89c4ff92SAndroid Build Coastguard Worker 
325*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
326*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
327*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(add->GetInputSlot(0));
328*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(add->GetInputSlot(1));
329*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
332*89c4ff92SAndroid Build Coastguard Worker 
333*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
334*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
335*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
336*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
337*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
338*89c4ff92SAndroid Build Coastguard Worker 
339*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
340*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc };
341*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
342*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
343*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
344*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
345*89c4ff92SAndroid Build Coastguard Worker 
346*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
349*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
350*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
351*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "sub");
352*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ sub (0) -> add (1) ]");
353*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "add");
354*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output");
355*89c4ff92SAndroid Build Coastguard Worker 
356*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
357*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
358*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
359*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
360*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
361*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
362*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
363*89c4ff92SAndroid Build Coastguard Worker 
364*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
365*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
366*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
367*89c4ff92SAndroid Build Coastguard Worker 
368*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
369*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
370*89c4ff92SAndroid Build Coastguard Worker 
371*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
372*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
373*89c4ff92SAndroid Build Coastguard Worker     {
374*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f
375*89c4ff92SAndroid Build Coastguard Worker     };
376*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
377*89c4ff92SAndroid Build Coastguard Worker     {
378*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
379*89c4ff92SAndroid Build Coastguard Worker     };
380*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
381*89c4ff92SAndroid Build Coastguard Worker     {
382*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
383*89c4ff92SAndroid Build Coastguard Worker     };
384*89c4ff92SAndroid Build Coastguard Worker 
385*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(12);
386*89c4ff92SAndroid Build Coastguard Worker 
387*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
388*89c4ff92SAndroid Build Coastguard Worker     {
389*89c4ff92SAndroid Build Coastguard Worker         13.0f, 11.0f, 11.0f, 9.0f, 7.0f, 7.0f, 7.0f, 5.0f, 5.0f, 3.0f, 3.0f, -5.0f
390*89c4ff92SAndroid Build Coastguard Worker     };
391*89c4ff92SAndroid Build Coastguard Worker 
392*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
393*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
394*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2);
395*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
396*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
397*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo2.SetConstant(true);
398*89c4ff92SAndroid Build Coastguard Worker 
399*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
400*89c4ff92SAndroid Build Coastguard Worker     {
401*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
402*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) },
403*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) }
404*89c4ff92SAndroid Build Coastguard Worker     };
405*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
406*89c4ff92SAndroid Build Coastguard Worker     {
407*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
408*89c4ff92SAndroid Build Coastguard Worker     };
409*89c4ff92SAndroid Build Coastguard Worker 
410*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
411*89c4ff92SAndroid Build Coastguard Worker 
412*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
413*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
414*89c4ff92SAndroid Build Coastguard Worker 
415*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
416*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
417*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
418*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
419*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
420*89c4ff92SAndroid Build Coastguard Worker 
421*89c4ff92SAndroid Build Coastguard Worker     // Contains ImportMemGeneric
422*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("ImportMemGeneric");
423*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
424*89c4ff92SAndroid Build Coastguard Worker 
425*89c4ff92SAndroid Build Coastguard Worker     // Contains SyncMemGeneric
426*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("SyncMemGeneric");
427*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
428*89c4ff92SAndroid Build Coastguard Worker 
429*89c4ff92SAndroid Build Coastguard Worker     // Does not contain CopyMemGeneric
430*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
431*89c4ff92SAndroid Build Coastguard Worker     CHECK(found == std::string::npos);
432*89c4ff92SAndroid Build Coastguard Worker 
433*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
434*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemImport));
435*89c4ff92SAndroid Build Coastguard Worker 
436*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
437*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
438*89c4ff92SAndroid Build Coastguard Worker }
439*89c4ff92SAndroid Build Coastguard Worker 
440*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackPaddingCopyFromCpuAcc")
441*89c4ff92SAndroid Build Coastguard Worker {
442*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
443*89c4ff92SAndroid Build Coastguard Worker 
444*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
445*89c4ff92SAndroid Build Coastguard Worker     MockImportBackendInitialiser initialiser; // Register the Mock Backend
446*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockImportBackendId());
447*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
448*89c4ff92SAndroid Build Coastguard Worker 
449*89c4ff92SAndroid Build Coastguard Worker     BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds();
450*89c4ff92SAndroid Build Coastguard Worker     if (backendIds.find("MockRef") == backendIds.end())
451*89c4ff92SAndroid Build Coastguard Worker     {
452*89c4ff92SAndroid Build Coastguard Worker         std::string message = "Cannot load MockRef";
453*89c4ff92SAndroid Build Coastguard Worker         FAIL(message);
454*89c4ff92SAndroid Build Coastguard Worker     }
455*89c4ff92SAndroid Build Coastguard Worker 
456*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run and allow fallback to CpuRef.
457*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
458*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
459*89c4ff92SAndroid Build Coastguard Worker 
460*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
461*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
462*89c4ff92SAndroid Build Coastguard Worker 
463*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
464*89c4ff92SAndroid Build Coastguard Worker 
465*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
466*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
467*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling");
468*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
469*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
470*89c4ff92SAndroid Build Coastguard Worker 
471*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
472*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
473*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(add->GetInputSlot(0));
474*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
475*89c4ff92SAndroid Build Coastguard Worker 
476*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
477*89c4ff92SAndroid Build Coastguard Worker     TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32);
478*89c4ff92SAndroid Build Coastguard Worker 
479*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(inputInfo);
480*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(poolingInfo);
481*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo);
482*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(poolingInfo);
483*89c4ff92SAndroid Build Coastguard Worker 
484*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
485*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc };
486*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
487*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
488*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
489*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
490*89c4ff92SAndroid Build Coastguard Worker 
491*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
492*89c4ff92SAndroid Build Coastguard Worker 
493*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
494*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
495*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "pooling");
496*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "[ pooling (0) -> add (0) ]");
497*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "add");
498*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "output");
499*89c4ff92SAndroid Build Coastguard Worker 
500*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
501*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
502*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
503*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
504*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
505*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
506*89c4ff92SAndroid Build Coastguard Worker 
507*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
508*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
509*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
510*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
511*89c4ff92SAndroid Build Coastguard Worker 
512*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
513*89c4ff92SAndroid Build Coastguard Worker 
514*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
515*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
516*89c4ff92SAndroid Build Coastguard Worker     {
517*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f
518*89c4ff92SAndroid Build Coastguard Worker     };
519*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
520*89c4ff92SAndroid Build Coastguard Worker     {
521*89c4ff92SAndroid Build Coastguard Worker         -1.0f, 3.0f
522*89c4ff92SAndroid Build Coastguard Worker     };
523*89c4ff92SAndroid Build Coastguard Worker 
524*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(2);
525*89c4ff92SAndroid Build Coastguard Worker 
526*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
527*89c4ff92SAndroid Build Coastguard Worker     {
528*89c4ff92SAndroid Build Coastguard Worker         5.0f, 15.0f
529*89c4ff92SAndroid Build Coastguard Worker     };
530*89c4ff92SAndroid Build Coastguard Worker 
531*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
532*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
533*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
534*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
535*89c4ff92SAndroid Build Coastguard Worker 
536*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
537*89c4ff92SAndroid Build Coastguard Worker     {
538*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
539*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }
540*89c4ff92SAndroid Build Coastguard Worker     };
541*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
542*89c4ff92SAndroid Build Coastguard Worker     {
543*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
544*89c4ff92SAndroid Build Coastguard Worker     };
545*89c4ff92SAndroid Build Coastguard Worker 
546*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
547*89c4ff92SAndroid Build Coastguard Worker 
548*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
549*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
550*89c4ff92SAndroid Build Coastguard Worker 
551*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
552*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
553*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
554*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
555*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
556*89c4ff92SAndroid Build Coastguard Worker 
557*89c4ff92SAndroid Build Coastguard Worker     // Contains CopyMemGeneric between the backends
558*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("CopyMemGeneric");
559*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
560*89c4ff92SAndroid Build Coastguard Worker 
561*89c4ff92SAndroid Build Coastguard Worker     // Contains SyncMemGeneric for the output
562*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("SyncMemGeneric");
563*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
564*89c4ff92SAndroid Build Coastguard Worker 
565*89c4ff92SAndroid Build Coastguard Worker     // Does not contain ImportMemGeneric
566*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("ImportMemGeneric");
567*89c4ff92SAndroid Build Coastguard Worker     CHECK(found == std::string::npos);
568*89c4ff92SAndroid Build Coastguard Worker 
569*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
570*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer3->GetType() == LayerType::MemCopy));
571*89c4ff92SAndroid Build Coastguard Worker 
572*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
573*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
574*89c4ff92SAndroid Build Coastguard Worker }
575*89c4ff92SAndroid Build Coastguard Worker 
576*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackDisableImportFromCpuAcc")
577*89c4ff92SAndroid Build Coastguard Worker {
578*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
579*89c4ff92SAndroid Build Coastguard Worker 
580*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
581*89c4ff92SAndroid Build Coastguard Worker     MockImportBackendInitialiser initialiser; // Register the Mock Backend
582*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockImportBackendId());
583*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
584*89c4ff92SAndroid Build Coastguard Worker 
585*89c4ff92SAndroid Build Coastguard Worker     BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds();
586*89c4ff92SAndroid Build Coastguard Worker     if (backendIds.find("MockRef") == backendIds.end())
587*89c4ff92SAndroid Build Coastguard Worker     {
588*89c4ff92SAndroid Build Coastguard Worker         std::string message = "Cannot load MockRef";
589*89c4ff92SAndroid Build Coastguard Worker         FAIL(message);
590*89c4ff92SAndroid Build Coastguard Worker     }
591*89c4ff92SAndroid Build Coastguard Worker 
592*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run and allow fallback to CpuRef.
593*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
594*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
595*89c4ff92SAndroid Build Coastguard Worker 
596*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
597*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
598*89c4ff92SAndroid Build Coastguard Worker 
599*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
600*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
601*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
602*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
603*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
604*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
605*89c4ff92SAndroid Build Coastguard Worker 
606*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
607*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
608*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(add->GetInputSlot(0));
609*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(add->GetInputSlot(1));
610*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
611*89c4ff92SAndroid Build Coastguard Worker 
612*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
613*89c4ff92SAndroid Build Coastguard Worker 
614*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
615*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
616*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
617*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
618*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
619*89c4ff92SAndroid Build Coastguard Worker 
620*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
621*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc };
622*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
623*89c4ff92SAndroid Build Coastguard Worker 
624*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
625*89c4ff92SAndroid Build Coastguard Worker 
626*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
627*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
628*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
629*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "sub");
630*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ sub (0) -> add (1) ]");
631*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "add");
632*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output");
633*89c4ff92SAndroid Build Coastguard Worker 
634*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
635*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
636*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
637*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
638*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
639*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
640*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
641*89c4ff92SAndroid Build Coastguard Worker 
642*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
643*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
644*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
645*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined);
646*89c4ff92SAndroid Build Coastguard Worker 
647*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
648*89c4ff92SAndroid Build Coastguard Worker 
649*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
650*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
651*89c4ff92SAndroid Build Coastguard Worker     {
652*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f
653*89c4ff92SAndroid Build Coastguard Worker     };
654*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
655*89c4ff92SAndroid Build Coastguard Worker     {
656*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
657*89c4ff92SAndroid Build Coastguard Worker     };
658*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
659*89c4ff92SAndroid Build Coastguard Worker     {
660*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
661*89c4ff92SAndroid Build Coastguard Worker     };
662*89c4ff92SAndroid Build Coastguard Worker 
663*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(12);
664*89c4ff92SAndroid Build Coastguard Worker 
665*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
666*89c4ff92SAndroid Build Coastguard Worker     {
667*89c4ff92SAndroid Build Coastguard Worker         13.0f, 11.0f, 11.0f, 9.0f, 7.0f, 7.0f, 7.0f, 5.0f, 5.0f, 3.0f, 3.0f, -5.0f
668*89c4ff92SAndroid Build Coastguard Worker     };
669*89c4ff92SAndroid Build Coastguard Worker 
670*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
671*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
672*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2);
673*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
674*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
675*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo2.SetConstant(true);
676*89c4ff92SAndroid Build Coastguard Worker 
677*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
678*89c4ff92SAndroid Build Coastguard Worker     {
679*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
680*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) },
681*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) }
682*89c4ff92SAndroid Build Coastguard Worker     };
683*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
684*89c4ff92SAndroid Build Coastguard Worker     {
685*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
686*89c4ff92SAndroid Build Coastguard Worker     };
687*89c4ff92SAndroid Build Coastguard Worker 
688*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
689*89c4ff92SAndroid Build Coastguard Worker 
690*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
691*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
692*89c4ff92SAndroid Build Coastguard Worker 
693*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
694*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
695*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
696*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
697*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
698*89c4ff92SAndroid Build Coastguard Worker 
699*89c4ff92SAndroid Build Coastguard Worker     // Contains CopyMemGeneric between the backends
700*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("CopyMemGeneric");
701*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
702*89c4ff92SAndroid Build Coastguard Worker 
703*89c4ff92SAndroid Build Coastguard Worker     // Does not contain ImportMemGeneric
704*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("ImportMemGeneric");
705*89c4ff92SAndroid Build Coastguard Worker     CHECK(found == std::string::npos);
706*89c4ff92SAndroid Build Coastguard Worker 
707*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
708*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
709*89c4ff92SAndroid Build Coastguard Worker 
710*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
711*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
712*89c4ff92SAndroid Build Coastguard Worker }
713*89c4ff92SAndroid Build Coastguard Worker 
714*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
715*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportEnabledFallbackToCl")
716*89c4ff92SAndroid Build Coastguard Worker {
717*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
718*89c4ff92SAndroid Build Coastguard Worker 
719*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
720*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
721*89c4ff92SAndroid Build Coastguard Worker 
722*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
723*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
724*89c4ff92SAndroid Build Coastguard Worker 
725*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
726*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
727*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
728*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
729*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
730*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
731*89c4ff92SAndroid Build Coastguard Worker 
732*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
733*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
734*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
735*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
736*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(output->GetInputSlot(0));
737*89c4ff92SAndroid Build Coastguard Worker 
738*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32);
739*89c4ff92SAndroid Build Coastguard Worker 
740*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
741*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
742*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
743*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
744*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
745*89c4ff92SAndroid Build Coastguard Worker 
746*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc };
747*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify GpuAcc for Subtraction layer
748*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
749*89c4ff92SAndroid Build Coastguard Worker 
750*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
751*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
752*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
753*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
754*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
755*89c4ff92SAndroid Build Coastguard Worker 
756*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
757*89c4ff92SAndroid Build Coastguard Worker 
758*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
759*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
760*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
761*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
762*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
763*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
764*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output");
765*89c4ff92SAndroid Build Coastguard Worker 
766*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
767*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
768*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
769*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
770*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
771*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
772*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
773*89c4ff92SAndroid Build Coastguard Worker 
774*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
775*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
776*89c4ff92SAndroid Build Coastguard Worker 
777*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
778*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::GpuAcc ));
779*89c4ff92SAndroid Build Coastguard Worker 
780*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
781*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
782*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
783*89c4ff92SAndroid Build Coastguard Worker 
784*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
785*89c4ff92SAndroid Build Coastguard Worker 
786*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
787*89c4ff92SAndroid Build Coastguard Worker 
788*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
789*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
790*89c4ff92SAndroid Build Coastguard Worker     {
791*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f
792*89c4ff92SAndroid Build Coastguard Worker     };
793*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
794*89c4ff92SAndroid Build Coastguard Worker     {
795*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f
796*89c4ff92SAndroid Build Coastguard Worker     };
797*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
798*89c4ff92SAndroid Build Coastguard Worker     {
799*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f
800*89c4ff92SAndroid Build Coastguard Worker     };
801*89c4ff92SAndroid Build Coastguard Worker 
802*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(16);
803*89c4ff92SAndroid Build Coastguard Worker 
804*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
805*89c4ff92SAndroid Build Coastguard Worker     {
806*89c4ff92SAndroid Build Coastguard Worker         11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f, 11.0f, 9.0f, 7.0f, 5.0f
807*89c4ff92SAndroid Build Coastguard Worker     };
808*89c4ff92SAndroid Build Coastguard Worker 
809*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
810*89c4ff92SAndroid Build Coastguard Worker     unsigned int numElements = info.GetNumElements();
811*89c4ff92SAndroid Build Coastguard Worker     size_t totalBytes = numElements * sizeof(float);
812*89c4ff92SAndroid Build Coastguard Worker 
813*89c4ff92SAndroid Build Coastguard Worker     // Prepare aligned data
814*89c4ff92SAndroid Build Coastguard Worker     const size_t alignment = 64;
815*89c4ff92SAndroid Build Coastguard Worker     size_t space = totalBytes + alignment + alignment;
816*89c4ff92SAndroid Build Coastguard Worker     auto inputData = std::make_unique<uint8_t[]>(space);
817*89c4ff92SAndroid Build Coastguard Worker     void* alignedInputPtr = inputData.get();
818*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::align(alignment, totalBytes, alignedInputPtr, space));
819*89c4ff92SAndroid Build Coastguard Worker 
820*89c4ff92SAndroid Build Coastguard Worker     auto* intputPtr = reinterpret_cast<float*>(alignedInputPtr);
821*89c4ff92SAndroid Build Coastguard Worker     std::copy(inputData2.begin(), inputData2.end(), intputPtr);
822*89c4ff92SAndroid Build Coastguard Worker 
823*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
824*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
825*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2);
826*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
827*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
828*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo2.SetConstant(true);
829*89c4ff92SAndroid Build Coastguard Worker 
830*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
831*89c4ff92SAndroid Build Coastguard Worker     {
832*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
833*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) },
834*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(inputTensorInfo2, alignedInputPtr) }
835*89c4ff92SAndroid Build Coastguard Worker     };
836*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
837*89c4ff92SAndroid Build Coastguard Worker     {
838*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
839*89c4ff92SAndroid Build Coastguard Worker     };
840*89c4ff92SAndroid Build Coastguard Worker 
841*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
842*89c4ff92SAndroid Build Coastguard Worker 
843*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
844*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
845*89c4ff92SAndroid Build Coastguard Worker 
846*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
847*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
848*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
849*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
850*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
851*89c4ff92SAndroid Build Coastguard Worker 
852*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using GpuAcc
853*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("ClSubtractionWorkload_Execute");
854*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
855*89c4ff92SAndroid Build Coastguard Worker 
856*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
857*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
858*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
859*89c4ff92SAndroid Build Coastguard Worker 
860*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
861*89c4ff92SAndroid Build Coastguard Worker     for(unsigned int i = 0; i < numElements; ++i)
862*89c4ff92SAndroid Build Coastguard Worker     {
863*89c4ff92SAndroid Build Coastguard Worker         CHECK(outputData[i] == expectedOutput[i]);
864*89c4ff92SAndroid Build Coastguard Worker     }
865*89c4ff92SAndroid Build Coastguard Worker     runtime->UnloadNetwork(netId);
866*89c4ff92SAndroid Build Coastguard Worker }
867*89c4ff92SAndroid Build Coastguard Worker 
868*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportDisabledFallbackToCl")
869*89c4ff92SAndroid Build Coastguard Worker {
870*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
871*89c4ff92SAndroid Build Coastguard Worker 
872*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
873*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
874*89c4ff92SAndroid Build Coastguard Worker 
875*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
876*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
877*89c4ff92SAndroid Build Coastguard Worker 
878*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
879*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
880*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
881*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
882*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
883*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
884*89c4ff92SAndroid Build Coastguard Worker 
885*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
886*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
887*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
888*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
889*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(output->GetInputSlot(0));
890*89c4ff92SAndroid Build Coastguard Worker 
891*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
892*89c4ff92SAndroid Build Coastguard Worker 
893*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
894*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
895*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
896*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
897*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
898*89c4ff92SAndroid Build Coastguard Worker 
899*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc };
900*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify GpuAcc for Subtraction layer
901*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
902*89c4ff92SAndroid Build Coastguard Worker 
903*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
904*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
905*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
906*89c4ff92SAndroid Build Coastguard Worker 
907*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
908*89c4ff92SAndroid Build Coastguard Worker 
909*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
910*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
911*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
912*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
913*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
914*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
915*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output");
916*89c4ff92SAndroid Build Coastguard Worker 
917*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
918*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
919*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
920*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
921*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
922*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
923*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
924*89c4ff92SAndroid Build Coastguard Worker 
925*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
926*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
927*89c4ff92SAndroid Build Coastguard Worker 
928*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
929*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::GpuAcc ));
930*89c4ff92SAndroid Build Coastguard Worker 
931*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
932*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
933*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet));
934*89c4ff92SAndroid Build Coastguard Worker 
935*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
936*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
937*89c4ff92SAndroid Build Coastguard Worker     {
938*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f
939*89c4ff92SAndroid Build Coastguard Worker     };
940*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
941*89c4ff92SAndroid Build Coastguard Worker     {
942*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
943*89c4ff92SAndroid Build Coastguard Worker     };
944*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
945*89c4ff92SAndroid Build Coastguard Worker     {
946*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
947*89c4ff92SAndroid Build Coastguard Worker     };
948*89c4ff92SAndroid Build Coastguard Worker 
949*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(12);
950*89c4ff92SAndroid Build Coastguard Worker 
951*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
952*89c4ff92SAndroid Build Coastguard Worker     {
953*89c4ff92SAndroid Build Coastguard Worker         11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f
954*89c4ff92SAndroid Build Coastguard Worker     };
955*89c4ff92SAndroid Build Coastguard Worker 
956*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
957*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
958*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2);
959*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
960*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
961*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo2.SetConstant(true);
962*89c4ff92SAndroid Build Coastguard Worker 
963*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
964*89c4ff92SAndroid Build Coastguard Worker     {
965*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
966*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) },
967*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) }
968*89c4ff92SAndroid Build Coastguard Worker     };
969*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
970*89c4ff92SAndroid Build Coastguard Worker     {
971*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
972*89c4ff92SAndroid Build Coastguard Worker     };
973*89c4ff92SAndroid Build Coastguard Worker 
974*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
975*89c4ff92SAndroid Build Coastguard Worker 
976*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
977*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
978*89c4ff92SAndroid Build Coastguard Worker 
979*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
980*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
981*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
982*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
983*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
984*89c4ff92SAndroid Build Coastguard Worker 
985*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using GpuAcc
986*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("ClSubtractionWorkload_Execute");
987*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
988*89c4ff92SAndroid Build Coastguard Worker 
989*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
990*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
991*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
992*89c4ff92SAndroid Build Coastguard Worker 
993*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
994*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
995*89c4ff92SAndroid Build Coastguard Worker }
996*89c4ff92SAndroid Build Coastguard Worker 
997*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportEnabledFallbackSubgraphToCl")
998*89c4ff92SAndroid Build Coastguard Worker {
999*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
1000*89c4ff92SAndroid Build Coastguard Worker 
1001*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
1002*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
1003*89c4ff92SAndroid Build Coastguard Worker 
1004*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
1005*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
1006*89c4ff92SAndroid Build Coastguard Worker 
1007*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
1008*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolWidth = 2;
1009*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolHeight = 2;
1010*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = 2;
1011*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY = 2;
1012*89c4ff92SAndroid Build Coastguard Worker 
1013*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
1014*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
1015*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
1016*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
1017*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
1018*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling");
1019*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
1020*89c4ff92SAndroid Build Coastguard Worker 
1021*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
1022*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
1023*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
1024*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
1025*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
1026*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
1027*89c4ff92SAndroid Build Coastguard Worker 
1028*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32);
1029*89c4ff92SAndroid Build Coastguard Worker     TensorInfo poolingInfo = TensorInfo({ 1, 2, 2, 1 }, DataType::Float32);
1030*89c4ff92SAndroid Build Coastguard Worker 
1031*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
1032*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
1033*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
1034*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
1035*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
1036*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo);
1037*89c4ff92SAndroid Build Coastguard Worker 
1038*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc };
1039*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify GpuAcc for Subtraction layer
1040*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
1041*89c4ff92SAndroid Build Coastguard Worker 
1042*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
1043*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
1044*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
1045*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
1046*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
1047*89c4ff92SAndroid Build Coastguard Worker 
1048*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
1049*89c4ff92SAndroid Build Coastguard Worker 
1050*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
1051*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
1052*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
1053*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
1054*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
1055*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
1056*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]");
1057*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling");
1058*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output");
1059*89c4ff92SAndroid Build Coastguard Worker 
1060*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
1061*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
1062*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
1063*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
1064*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
1065*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
1066*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
1067*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer6, layer7));
1068*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer7, layer8));
1069*89c4ff92SAndroid Build Coastguard Worker 
1070*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
1071*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
1072*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer6->GetType() == LayerType::MemCopy));
1073*89c4ff92SAndroid Build Coastguard Worker 
1074*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
1075*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::GpuAcc ));
1076*89c4ff92SAndroid Build Coastguard Worker 
1077*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
1078*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
1079*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
1080*89c4ff92SAndroid Build Coastguard Worker 
1081*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
1082*89c4ff92SAndroid Build Coastguard Worker 
1083*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
1084*89c4ff92SAndroid Build Coastguard Worker 
1085*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
1086*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
1087*89c4ff92SAndroid Build Coastguard Worker     {
1088*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f
1089*89c4ff92SAndroid Build Coastguard Worker     };
1090*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
1091*89c4ff92SAndroid Build Coastguard Worker     {
1092*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f
1093*89c4ff92SAndroid Build Coastguard Worker     };
1094*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
1095*89c4ff92SAndroid Build Coastguard Worker     {
1096*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f
1097*89c4ff92SAndroid Build Coastguard Worker     };
1098*89c4ff92SAndroid Build Coastguard Worker 
1099*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
1100*89c4ff92SAndroid Build Coastguard Worker 
1101*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput{ 11.0f, 3.0f, -5.0f, 11.0f };
1102*89c4ff92SAndroid Build Coastguard Worker 
1103*89c4ff92SAndroid Build Coastguard Worker     // Prepare aligned data
1104*89c4ff92SAndroid Build Coastguard Worker     unsigned int numElements = info.GetNumElements();
1105*89c4ff92SAndroid Build Coastguard Worker     size_t totalBytes = numElements * sizeof(float);
1106*89c4ff92SAndroid Build Coastguard Worker     const size_t alignment = 64;
1107*89c4ff92SAndroid Build Coastguard Worker     size_t space = totalBytes + alignment + alignment;
1108*89c4ff92SAndroid Build Coastguard Worker     auto inputData = std::make_unique<uint8_t[]>(space);
1109*89c4ff92SAndroid Build Coastguard Worker     void* alignedInputPtr = inputData.get();
1110*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::align(alignment, totalBytes, alignedInputPtr, space));
1111*89c4ff92SAndroid Build Coastguard Worker 
1112*89c4ff92SAndroid Build Coastguard Worker     auto* intputPtr = reinterpret_cast<float*>(alignedInputPtr);
1113*89c4ff92SAndroid Build Coastguard Worker     std::copy(inputData2.begin(), inputData2.end(), intputPtr);
1114*89c4ff92SAndroid Build Coastguard Worker 
1115*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
1116*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
1117*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2);
1118*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
1119*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
1120*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo2.SetConstant(true);
1121*89c4ff92SAndroid Build Coastguard Worker 
1122*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
1123*89c4ff92SAndroid Build Coastguard Worker     {
1124*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
1125*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) },
1126*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(inputTensorInfo2, alignedInputPtr) }
1127*89c4ff92SAndroid Build Coastguard Worker     };
1128*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
1129*89c4ff92SAndroid Build Coastguard Worker     {
1130*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
1131*89c4ff92SAndroid Build Coastguard Worker     };
1132*89c4ff92SAndroid Build Coastguard Worker 
1133*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
1134*89c4ff92SAndroid Build Coastguard Worker 
1135*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
1136*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
1137*89c4ff92SAndroid Build Coastguard Worker 
1138*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
1139*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
1140*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
1141*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
1142*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
1143*89c4ff92SAndroid Build Coastguard Worker 
1144*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using GpuAcc
1145*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("ClSubtractionWorkload_Execute");
1146*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
1147*89c4ff92SAndroid Build Coastguard Worker 
1148*89c4ff92SAndroid Build Coastguard Worker     // Correctly switch back to CpuAcc
1149*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("NeonPooling2dWorkload_Execute");
1150*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
1151*89c4ff92SAndroid Build Coastguard Worker 
1152*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
1153*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
1154*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
1155*89c4ff92SAndroid Build Coastguard Worker 
1156*89c4ff92SAndroid Build Coastguard Worker     // Contains SyncMemGeneric for output
1157*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("SyncMemGeneric");
1158*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
1159*89c4ff92SAndroid Build Coastguard Worker 
1160*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
1161*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
1162*89c4ff92SAndroid Build Coastguard Worker     runtime->UnloadNetwork(netId);
1163*89c4ff92SAndroid Build Coastguard Worker }
1164*89c4ff92SAndroid Build Coastguard Worker 
1165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportDisableFallbackSubgraphToCl")
1166*89c4ff92SAndroid Build Coastguard Worker {
1167*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
1168*89c4ff92SAndroid Build Coastguard Worker 
1169*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
1170*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
1171*89c4ff92SAndroid Build Coastguard Worker 
1172*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
1173*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
1174*89c4ff92SAndroid Build Coastguard Worker 
1175*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
1176*89c4ff92SAndroid Build Coastguard Worker 
1177*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
1178*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
1179*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
1180*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
1181*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
1182*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling");
1183*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
1184*89c4ff92SAndroid Build Coastguard Worker 
1185*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
1186*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
1187*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
1188*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
1189*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
1190*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
1191*89c4ff92SAndroid Build Coastguard Worker 
1192*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
1193*89c4ff92SAndroid Build Coastguard Worker     TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32);
1194*89c4ff92SAndroid Build Coastguard Worker 
1195*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
1196*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
1197*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
1198*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
1199*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
1200*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo);
1201*89c4ff92SAndroid Build Coastguard Worker 
1202*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc };
1203*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify GpuAcc for Subtraction layer
1204*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
1205*89c4ff92SAndroid Build Coastguard Worker 
1206*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
1207*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
1208*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
1209*89c4ff92SAndroid Build Coastguard Worker 
1210*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
1211*89c4ff92SAndroid Build Coastguard Worker 
1212*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
1213*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
1214*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
1215*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
1216*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
1217*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
1218*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]");
1219*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling");
1220*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output");
1221*89c4ff92SAndroid Build Coastguard Worker 
1222*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
1223*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
1224*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
1225*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
1226*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
1227*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
1228*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
1229*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer6, layer7));
1230*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer7, layer8));
1231*89c4ff92SAndroid Build Coastguard Worker 
1232*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
1233*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
1234*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer6->GetType() == LayerType::MemCopy));
1235*89c4ff92SAndroid Build Coastguard Worker 
1236*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
1237*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::GpuAcc ));
1238*89c4ff92SAndroid Build Coastguard Worker 
1239*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
1240*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
1241*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet));
1242*89c4ff92SAndroid Build Coastguard Worker 
1243*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
1244*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
1245*89c4ff92SAndroid Build Coastguard Worker     {
1246*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f
1247*89c4ff92SAndroid Build Coastguard Worker     };
1248*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
1249*89c4ff92SAndroid Build Coastguard Worker     {
1250*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
1251*89c4ff92SAndroid Build Coastguard Worker     };
1252*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
1253*89c4ff92SAndroid Build Coastguard Worker     {
1254*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
1255*89c4ff92SAndroid Build Coastguard Worker     };
1256*89c4ff92SAndroid Build Coastguard Worker 
1257*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(2);
1258*89c4ff92SAndroid Build Coastguard Worker 
1259*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput{ 11.0f, -1.0f };
1260*89c4ff92SAndroid Build Coastguard Worker 
1261*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0);
1262*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1);
1263*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2);
1264*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo0.SetConstant(true);
1265*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo1.SetConstant(true);
1266*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo2.SetConstant(true);
1267*89c4ff92SAndroid Build Coastguard Worker 
1268*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
1269*89c4ff92SAndroid Build Coastguard Worker     {
1270*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) },
1271*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) },
1272*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) }
1273*89c4ff92SAndroid Build Coastguard Worker     };
1274*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
1275*89c4ff92SAndroid Build Coastguard Worker     {
1276*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
1277*89c4ff92SAndroid Build Coastguard Worker     };
1278*89c4ff92SAndroid Build Coastguard Worker 
1279*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
1280*89c4ff92SAndroid Build Coastguard Worker 
1281*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
1282*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
1283*89c4ff92SAndroid Build Coastguard Worker 
1284*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
1285*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
1286*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
1287*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
1288*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
1289*89c4ff92SAndroid Build Coastguard Worker 
1290*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using GpuAcc
1291*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("ClSubtractionWorkload_Execute");
1292*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
1293*89c4ff92SAndroid Build Coastguard Worker 
1294*89c4ff92SAndroid Build Coastguard Worker     // Correctly switch back to CpuAcc
1295*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("NeonPooling2dWorkload_Execute");
1296*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
1297*89c4ff92SAndroid Build Coastguard Worker 
1298*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
1299*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
1300*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
1301*89c4ff92SAndroid Build Coastguard Worker 
1302*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
1303*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
1304*89c4ff92SAndroid Build Coastguard Worker }
1305*89c4ff92SAndroid Build Coastguard Worker #endif
1306*89c4ff92SAndroid Build Coastguard Worker 
1307*89c4ff92SAndroid Build Coastguard Worker }
1308