xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/EndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <Profiling.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <vector>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker template<typename T>
ConstantUsageTest(const std::vector<BackendId> & computeDevice,const TensorInfo & commonTensorInfo,const std::vector<T> & inputData,const std::vector<T> & constantData,const std::vector<T> & expectedOutputData)27*89c4ff92SAndroid Build Coastguard Worker bool ConstantUsageTest(const std::vector<BackendId>& computeDevice,
28*89c4ff92SAndroid Build Coastguard Worker                        const TensorInfo& commonTensorInfo,
29*89c4ff92SAndroid Build Coastguard Worker                        const std::vector<T>& inputData,
30*89c4ff92SAndroid Build Coastguard Worker                        const std::vector<T>& constantData,
31*89c4ff92SAndroid Build Coastguard Worker                        const std::vector<T>& expectedOutputData)
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
34*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
35*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
38*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
41*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* constant = net->AddConstantLayer(ConstTensor(commonTensorInfo, constantData));
42*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
43*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddAdditionLayer();
44*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
45*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(add->GetInputSlot(0));
48*89c4ff92SAndroid Build Coastguard Worker     constant->GetOutputSlot(0).Connect(add->GetInputSlot(1));
49*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     // Sets the tensors in the network.
52*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
53*89c4ff92SAndroid Build Coastguard Worker     constant->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
54*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
57*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, computeDevice, runtime->GetDeviceSpec());
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     // Loads it into the runtime.
60*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
61*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
62*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage);
63*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
66*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> outputData(inputData.size());
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
69*89c4ff92SAndroid Build Coastguard Worker     {
70*89c4ff92SAndroid Build Coastguard Worker         {0, ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())}
71*89c4ff92SAndroid Build Coastguard Worker     };
72*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
73*89c4ff92SAndroid Build Coastguard Worker     {
74*89c4ff92SAndroid Build Coastguard Worker         {0, Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
75*89c4ff92SAndroid Build Coastguard Worker     };
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     // Does the inference.
78*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     // Checks the results.
81*89c4ff92SAndroid Build Coastguard Worker     return outputData == expectedOutputData;
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker 
ConstantUsageFloat32Test(const std::vector<BackendId> & backends)84*89c4ff92SAndroid Build Coastguard Worker inline bool ConstantUsageFloat32Test(const std::vector<BackendId>& backends)
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker     TensorInfo commonTensorInfo({ 2, 3 }, DataType::Float32);
87*89c4ff92SAndroid Build Coastguard Worker     commonTensorInfo.SetConstant(true);
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     return ConstantUsageTest(backends,
90*89c4ff92SAndroid Build Coastguard Worker         commonTensorInfo,
91*89c4ff92SAndroid Build Coastguard Worker         std::vector<float>{ 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }, // Input.
92*89c4ff92SAndroid Build Coastguard Worker         std::vector<float>{ 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }, // Const input.
93*89c4ff92SAndroid Build Coastguard Worker         std::vector<float>{ 7.f, 7.f, 7.f, 7.f, 7.f, 7.f }  // Expected output.
94*89c4ff92SAndroid Build Coastguard Worker     );
95*89c4ff92SAndroid Build Coastguard Worker }
96*89c4ff92SAndroid Build Coastguard Worker 
ConstantUsageUint8Test(const std::vector<BackendId> & backends)97*89c4ff92SAndroid Build Coastguard Worker inline bool ConstantUsageUint8Test(const std::vector<BackendId>& backends)
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker     TensorInfo commonTensorInfo({ 2, 3 }, DataType::QAsymmU8);
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker     const float scale = 0.023529f;
102*89c4ff92SAndroid Build Coastguard Worker     const int8_t offset = -43;
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     commonTensorInfo.SetQuantizationScale(scale);
105*89c4ff92SAndroid Build Coastguard Worker     commonTensorInfo.SetQuantizationOffset(offset);
106*89c4ff92SAndroid Build Coastguard Worker     commonTensorInfo.SetConstant(true);
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     return ConstantUsageTest(backends,
109*89c4ff92SAndroid Build Coastguard Worker         commonTensorInfo,
110*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::QuantizedVector<uint8_t>({ 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }, scale, offset), // Input.
111*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::QuantizedVector<uint8_t>({ 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }, scale, offset), // Const input.
112*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::QuantizedVector<uint8_t>({ 7.f, 7.f, 7.f, 7.f, 7.f, 7.f }, scale, offset)  // Expected output.
113*89c4ff92SAndroid Build Coastguard Worker     );
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker 
// Utility function to count the non-overlapping occurrences of a substring within a string.
//
// string    - text to search; it is only read.
// substring - text to look for.
//
// Returns the number of matches found when scanning left to right and resuming each
// search immediately after the previous match. An empty substring yields 0.
inline int SubStringCounter(std::string& string, std::string&& substring)
{
    // An empty needle matches at the current offset without ever advancing it,
    // which would keep the search loop below spinning forever.
    if (substring.empty())
    {
        return 0;
    }

    int count = 0;
    // Resume each search just past the previous hit so the same occurrence is not counted twice.
    for (std::size_t found = string.find(substring);
         found != std::string::npos;
         found = string.find(substring, found + substring.length()))
    {
        ++count;
    }
    return count;
}
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker template<DataType ArmnnIType, DataType ArmnnOType,
132*89c4ff92SAndroid Build Coastguard Worker          typename TInput = ResolveType<ArmnnIType>, typename TOutput = ResolveType<ArmnnOType>>
EndToEndLayerTestImpl(INetworkPtr network,const std::map<int,std::vector<TInput>> & inputTensorData,const std::map<int,std::vector<TOutput>> & expectedOutputData,std::vector<BackendId> backends,float tolerance=0.000001f)133*89c4ff92SAndroid Build Coastguard Worker void EndToEndLayerTestImpl(INetworkPtr network,
134*89c4ff92SAndroid Build Coastguard Worker                            const std::map<int, std::vector<TInput>>& inputTensorData,
135*89c4ff92SAndroid Build Coastguard Worker                            const std::map<int, std::vector<TOutput>>& expectedOutputData,
136*89c4ff92SAndroid Build Coastguard Worker                            std::vector<BackendId> backends,
137*89c4ff92SAndroid Build Coastguard Worker                            float tolerance = 0.000001f)
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
140*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
141*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
144*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec());
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker     // Loads it into the runtime.
147*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
148*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
149*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage);
150*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors;
153*89c4ff92SAndroid Build Coastguard Worker     inputTensors.reserve(inputTensorData.size());
154*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : inputTensorData)
155*89c4ff92SAndroid Build Coastguard Worker     {
156*89c4ff92SAndroid Build Coastguard Worker         inputTensors.push_back({it.first,
157*89c4ff92SAndroid Build Coastguard Worker                                 ConstTensor(runtime->GetInputTensorInfo(netId, it.first), it.second.data())});
158*89c4ff92SAndroid Build Coastguard Worker     }
159*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors;
160*89c4ff92SAndroid Build Coastguard Worker     outputTensors.reserve(expectedOutputData.size());
161*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<TOutput>> outputStorage;
162*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
163*89c4ff92SAndroid Build Coastguard Worker     {
164*89c4ff92SAndroid Build Coastguard Worker         std::vector<TOutput> out(it.second.size());
165*89c4ff92SAndroid Build Coastguard Worker         outputStorage.emplace(it.first, out);
166*89c4ff92SAndroid Build Coastguard Worker         outputTensors.push_back({it.first,
167*89c4ff92SAndroid Build Coastguard Worker                                  Tensor(runtime->GetOutputTensorInfo(netId, it.first),
168*89c4ff92SAndroid Build Coastguard Worker                                                outputStorage.at(it.first).data())});
169*89c4ff92SAndroid Build Coastguard Worker     }
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     // Does the inference.
172*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker     // Checks the results.
175*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
176*89c4ff92SAndroid Build Coastguard Worker     {
177*89c4ff92SAndroid Build Coastguard Worker         std::vector<TOutput> out = outputStorage.at(it.first);
178*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < out.size(); ++i)
179*89c4ff92SAndroid Build Coastguard Worker         {
180*89c4ff92SAndroid Build Coastguard Worker             CHECK_MESSAGE(Compare<ArmnnOType>(it.second[i], out[i], tolerance) == true,
181*89c4ff92SAndroid Build Coastguard Worker                     "Actual output: " << out[i] << ". Expected output:" << it.second[i]);
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker         }
184*89c4ff92SAndroid Build Coastguard Worker     }
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker 
ImportNonAlignedInputPointerTest(std::vector<BackendId> backends)187*89c4ff92SAndroid Build Coastguard Worker inline void ImportNonAlignedInputPointerTest(std::vector<BackendId> backends)
188*89c4ff92SAndroid Build Coastguard Worker {
189*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
192*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
193*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(armnn::IRuntime::Create(options));
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     // build up the structure of the network
196*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
199*89c4ff92SAndroid Build Coastguard Worker 
200*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
201*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
202*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddActivationLayer(descriptor);
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
207*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
210*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker     // Optimize the network
213*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optimizedOptions;
214*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetImportEnabled(true);
215*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optimizedOptions);
216*89c4ff92SAndroid Build Coastguard Worker     CHECK(optNet);
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     // Loads it into the runtime.
219*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
220*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
221*89c4ff92SAndroid Build Coastguard Worker     // Enable Importing
222*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Undefined);
223*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
224*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
227*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
228*89c4ff92SAndroid Build Coastguard Worker     {
229*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
230*89c4ff92SAndroid Build Coastguard Worker     };
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker     // Misaligned input
233*89c4ff92SAndroid Build Coastguard Worker     float* misalignedInputData = reinterpret_cast<float*>(reinterpret_cast<char*>(inputData.data()) + 1);
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     // Aligned output
238*89c4ff92SAndroid Build Coastguard Worker     float* alignedOutputData = outputData.data();
239*89c4ff92SAndroid Build Coastguard Worker 
240*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
241*89c4ff92SAndroid Build Coastguard Worker     {
242*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), misalignedInputData)},
243*89c4ff92SAndroid Build Coastguard Worker     };
244*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
245*89c4ff92SAndroid Build Coastguard Worker     {
246*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), alignedOutputData)}
247*89c4ff92SAndroid Build Coastguard Worker     };
248*89c4ff92SAndroid Build Coastguard Worker 
249*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and expect it to fail with a ImportMemoryException
252*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(runtime->EnqueueWorkload(netId, inputTensors, outputTensors), MemoryImportException);
253*89c4ff92SAndroid Build Coastguard Worker }
254*89c4ff92SAndroid Build Coastguard Worker 
ExportNonAlignedOutputPointerTest(std::vector<BackendId> backends)255*89c4ff92SAndroid Build Coastguard Worker inline void ExportNonAlignedOutputPointerTest(std::vector<BackendId> backends)
256*89c4ff92SAndroid Build Coastguard Worker {
257*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
258*89c4ff92SAndroid Build Coastguard Worker 
259*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
260*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
261*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(armnn::IRuntime::Create(options));
262*89c4ff92SAndroid Build Coastguard Worker 
263*89c4ff92SAndroid Build Coastguard Worker     // build up the structure of the network
264*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
265*89c4ff92SAndroid Build Coastguard Worker 
266*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
269*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
270*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddActivationLayer(descriptor);
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
273*89c4ff92SAndroid Build Coastguard Worker 
274*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
275*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
276*89c4ff92SAndroid Build Coastguard Worker 
277*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
278*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker     // Optimize the network
281*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optimizedOptions;
282*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetImportEnabled(true);
283*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetExportEnabled(true);
284*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optimizedOptions);
285*89c4ff92SAndroid Build Coastguard Worker     CHECK(optNet);
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker     // Loads it into the runtime.
288*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
289*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
290*89c4ff92SAndroid Build Coastguard Worker     // Enable Importing and Exporting
291*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
292*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
293*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
294*89c4ff92SAndroid Build Coastguard Worker 
295*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
296*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
297*89c4ff92SAndroid Build Coastguard Worker     {
298*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f, 5.0f
299*89c4ff92SAndroid Build Coastguard Worker     };
300*89c4ff92SAndroid Build Coastguard Worker 
301*89c4ff92SAndroid Build Coastguard Worker     // Aligned input
302*89c4ff92SAndroid Build Coastguard Worker     float* alignedInputData = inputData.data();
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(5);
305*89c4ff92SAndroid Build Coastguard Worker 
306*89c4ff92SAndroid Build Coastguard Worker     // Misaligned output
307*89c4ff92SAndroid Build Coastguard Worker     float* misalignedOutputData = reinterpret_cast<float*>(reinterpret_cast<char*>(outputData.data()) + 1);
308*89c4ff92SAndroid Build Coastguard Worker 
309*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
310*89c4ff92SAndroid Build Coastguard Worker     {
311*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), alignedInputData)},
312*89c4ff92SAndroid Build Coastguard Worker     };
313*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
314*89c4ff92SAndroid Build Coastguard Worker     {
315*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), misalignedOutputData)}
316*89c4ff92SAndroid Build Coastguard Worker     };
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and expect it to fail with a ExportMemoryException
319*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] == Compute::CpuAcc)
320*89c4ff92SAndroid Build Coastguard Worker     {
321*89c4ff92SAndroid Build Coastguard Worker         // For CpuAcc the NeonTensorHandle will throw its own exception on misaligned memory
322*89c4ff92SAndroid Build Coastguard Worker         CHECK_THROWS_AS(runtime->EnqueueWorkload(netId, inputTensors, outputTensors), MemoryImportException);
323*89c4ff92SAndroid Build Coastguard Worker     }
324*89c4ff92SAndroid Build Coastguard Worker     else
325*89c4ff92SAndroid Build Coastguard Worker     {
326*89c4ff92SAndroid Build Coastguard Worker         CHECK_THROWS_AS(runtime->EnqueueWorkload(netId, inputTensors, outputTensors), MemoryExportException);
327*89c4ff92SAndroid Build Coastguard Worker     }
328*89c4ff92SAndroid Build Coastguard Worker }
329*89c4ff92SAndroid Build Coastguard Worker 
ImportAlignedPointerTest(std::vector<BackendId> backends)330*89c4ff92SAndroid Build Coastguard Worker inline void ImportAlignedPointerTest(std::vector<BackendId> backends)
331*89c4ff92SAndroid Build Coastguard Worker {
332*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
333*89c4ff92SAndroid Build Coastguard Worker 
334*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
335*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
336*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(armnn::IRuntime::Create(options));
337*89c4ff92SAndroid Build Coastguard Worker 
338*89c4ff92SAndroid Build Coastguard Worker     // build up the structure of the network
339*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
340*89c4ff92SAndroid Build Coastguard Worker 
341*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
342*89c4ff92SAndroid Build Coastguard Worker 
343*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
344*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
345*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddActivationLayer(descriptor);
346*89c4ff92SAndroid Build Coastguard Worker 
347*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
348*89c4ff92SAndroid Build Coastguard Worker 
349*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
350*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
351*89c4ff92SAndroid Build Coastguard Worker 
352*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
353*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
354*89c4ff92SAndroid Build Coastguard Worker 
355*89c4ff92SAndroid Build Coastguard Worker     // Optimize the network
356*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optimizedOptions;
357*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetImportEnabled(true);
358*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetExportEnabled(true);
359*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optimizedOptions);
360*89c4ff92SAndroid Build Coastguard Worker     CHECK(optNet);
361*89c4ff92SAndroid Build Coastguard Worker 
362*89c4ff92SAndroid Build Coastguard Worker     // Loads it into the runtime.
363*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
364*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
365*89c4ff92SAndroid Build Coastguard Worker     // Enable Importing
366*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
367*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
368*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
369*89c4ff92SAndroid Build Coastguard Worker 
370*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
371*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
372*89c4ff92SAndroid Build Coastguard Worker     {
373*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
374*89c4ff92SAndroid Build Coastguard Worker     };
375*89c4ff92SAndroid Build Coastguard Worker 
376*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
377*89c4ff92SAndroid Build Coastguard Worker 
378*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
379*89c4ff92SAndroid Build Coastguard Worker     {
380*89c4ff92SAndroid Build Coastguard Worker         1.0f, 4.0f, 9.0f, 16.0f
381*89c4ff92SAndroid Build Coastguard Worker     };
382*89c4ff92SAndroid Build Coastguard Worker 
383*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
384*89c4ff92SAndroid Build Coastguard Worker     {
385*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
386*89c4ff92SAndroid Build Coastguard Worker     };
387*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
388*89c4ff92SAndroid Build Coastguard Worker     {
389*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
390*89c4ff92SAndroid Build Coastguard Worker     };
391*89c4ff92SAndroid Build Coastguard Worker 
392*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
393*89c4ff92SAndroid Build Coastguard Worker 
394*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
395*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
396*89c4ff92SAndroid Build Coastguard Worker 
397*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
398*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
399*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
400*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
401*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
402*89c4ff92SAndroid Build Coastguard Worker 
403*89c4ff92SAndroid Build Coastguard Worker     // Contains ActivationWorkload
404*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("ActivationWorkload");
405*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
406*89c4ff92SAndroid Build Coastguard Worker 
407*89c4ff92SAndroid Build Coastguard Worker     // Contains SyncMemGeneric
408*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("SyncMemGeneric");
409*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
410*89c4ff92SAndroid Build Coastguard Worker 
411*89c4ff92SAndroid Build Coastguard Worker     // Does not contain CopyMemGeneric
412*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
413*89c4ff92SAndroid Build Coastguard Worker     CHECK(found == std::string::npos);
414*89c4ff92SAndroid Build Coastguard Worker 
415*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
416*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
417*89c4ff92SAndroid Build Coastguard Worker }
418*89c4ff92SAndroid Build Coastguard Worker 
ImportOnlyWorkload(std::vector<BackendId> backends)419*89c4ff92SAndroid Build Coastguard Worker inline void ImportOnlyWorkload(std::vector<BackendId> backends)
420*89c4ff92SAndroid Build Coastguard Worker {
421*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
422*89c4ff92SAndroid Build Coastguard Worker 
423*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
424*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
425*89c4ff92SAndroid Build Coastguard Worker 
426*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
427*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
428*89c4ff92SAndroid Build Coastguard Worker 
429*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
430*89c4ff92SAndroid Build Coastguard Worker 
431*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
432*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
433*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddActivationLayer(descriptor);
434*89c4ff92SAndroid Build Coastguard Worker 
435*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
436*89c4ff92SAndroid Build Coastguard Worker 
437*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
438*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
439*89c4ff92SAndroid Build Coastguard Worker 
440*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
441*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
442*89c4ff92SAndroid Build Coastguard Worker 
443*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
444*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optimizedOptions;
445*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetImportEnabled(true);
446*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optimizedOptions);
447*89c4ff92SAndroid Build Coastguard Worker 
448*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
449*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
450*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
451*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
452*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Undefined);
453*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
454*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
455*89c4ff92SAndroid Build Coastguard Worker 
456*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
457*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
458*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
459*89c4ff92SAndroid Build Coastguard Worker     {
460*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
461*89c4ff92SAndroid Build Coastguard Worker     };
462*89c4ff92SAndroid Build Coastguard Worker 
463*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
464*89c4ff92SAndroid Build Coastguard Worker 
465*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
466*89c4ff92SAndroid Build Coastguard Worker     {
467*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
468*89c4ff92SAndroid Build Coastguard Worker     };
469*89c4ff92SAndroid Build Coastguard Worker 
470*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
471*89c4ff92SAndroid Build Coastguard Worker 
472*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
473*89c4ff92SAndroid Build Coastguard Worker     {
474*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
475*89c4ff92SAndroid Build Coastguard Worker     };
476*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
477*89c4ff92SAndroid Build Coastguard Worker     {
478*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
479*89c4ff92SAndroid Build Coastguard Worker     };
480*89c4ff92SAndroid Build Coastguard Worker 
481*89c4ff92SAndroid Build Coastguard Worker     INFO("Get Profiler");
482*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
483*89c4ff92SAndroid Build Coastguard Worker 
484*89c4ff92SAndroid Build Coastguard Worker     INFO("Run Inference");
485*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
486*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
487*89c4ff92SAndroid Build Coastguard Worker 
488*89c4ff92SAndroid Build Coastguard Worker     INFO("Print Profiler");
489*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
490*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
491*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
492*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
493*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
494*89c4ff92SAndroid Build Coastguard Worker 
495*89c4ff92SAndroid Build Coastguard Worker     // Check there are no SyncMemGeneric workloads as we didn't export
496*89c4ff92SAndroid Build Coastguard Worker     INFO("Find SyncMemGeneric");
497*89c4ff92SAndroid Build Coastguard Worker     int count = SubStringCounter(dump, "SyncMemGeneric");
498*89c4ff92SAndroid Build Coastguard Worker     CHECK(count == 0);
499*89c4ff92SAndroid Build Coastguard Worker 
500*89c4ff92SAndroid Build Coastguard Worker     // Should only be 1 CopyMemGeneric for the output as we imported
501*89c4ff92SAndroid Build Coastguard Worker     INFO("Find CopyMemGeneric");
502*89c4ff92SAndroid Build Coastguard Worker     count = SubStringCounter(dump, "CopyMemGeneric");
503*89c4ff92SAndroid Build Coastguard Worker     CHECK(count == 1);
504*89c4ff92SAndroid Build Coastguard Worker 
505*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
506*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutput.begin(), expectedOutput.end()));
507*89c4ff92SAndroid Build Coastguard Worker }
508*89c4ff92SAndroid Build Coastguard Worker 
ExportOnlyWorkload(std::vector<BackendId> backends)509*89c4ff92SAndroid Build Coastguard Worker inline void ExportOnlyWorkload(std::vector<BackendId> backends)
510*89c4ff92SAndroid Build Coastguard Worker {
511*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
512*89c4ff92SAndroid Build Coastguard Worker 
513*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
514*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
515*89c4ff92SAndroid Build Coastguard Worker 
516*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
517*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
518*89c4ff92SAndroid Build Coastguard Worker 
519*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
520*89c4ff92SAndroid Build Coastguard Worker 
521*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
522*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
523*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddActivationLayer(descriptor);
524*89c4ff92SAndroid Build Coastguard Worker 
525*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
526*89c4ff92SAndroid Build Coastguard Worker 
527*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
528*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
529*89c4ff92SAndroid Build Coastguard Worker 
530*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
531*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
532*89c4ff92SAndroid Build Coastguard Worker 
533*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
534*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optimizedOptions;
535*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetExportEnabled(true);
536*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optimizedOptions);
537*89c4ff92SAndroid Build Coastguard Worker 
538*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
539*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
540*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
541*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
542*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Malloc);
543*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
544*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
545*89c4ff92SAndroid Build Coastguard Worker 
546*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
547*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
548*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
549*89c4ff92SAndroid Build Coastguard Worker     {
550*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
551*89c4ff92SAndroid Build Coastguard Worker     };
552*89c4ff92SAndroid Build Coastguard Worker 
553*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
554*89c4ff92SAndroid Build Coastguard Worker 
555*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
556*89c4ff92SAndroid Build Coastguard Worker     {
557*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
558*89c4ff92SAndroid Build Coastguard Worker     };
559*89c4ff92SAndroid Build Coastguard Worker 
560*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
561*89c4ff92SAndroid Build Coastguard Worker 
562*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
563*89c4ff92SAndroid Build Coastguard Worker     {
564*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
565*89c4ff92SAndroid Build Coastguard Worker     };
566*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
567*89c4ff92SAndroid Build Coastguard Worker     {
568*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
569*89c4ff92SAndroid Build Coastguard Worker     };
570*89c4ff92SAndroid Build Coastguard Worker 
571*89c4ff92SAndroid Build Coastguard Worker     INFO("Get Profiler");
572*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
573*89c4ff92SAndroid Build Coastguard Worker 
574*89c4ff92SAndroid Build Coastguard Worker     INFO("Run Inference");
575*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
576*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
577*89c4ff92SAndroid Build Coastguard Worker 
578*89c4ff92SAndroid Build Coastguard Worker     INFO("Print Profiler");
579*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
580*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
581*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
582*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
583*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
584*89c4ff92SAndroid Build Coastguard Worker 
585*89c4ff92SAndroid Build Coastguard Worker     // Check there is a SyncMemGeneric workload as we exported
586*89c4ff92SAndroid Build Coastguard Worker     INFO("Find SyncMemGeneric");
587*89c4ff92SAndroid Build Coastguard Worker     int count = SubStringCounter(dump, "SyncMemGeneric");
588*89c4ff92SAndroid Build Coastguard Worker     CHECK(count == 1);
589*89c4ff92SAndroid Build Coastguard Worker 
590*89c4ff92SAndroid Build Coastguard Worker     // Should be 1 CopyMemGeneric for the output as we did not import
591*89c4ff92SAndroid Build Coastguard Worker     INFO("Find CopyMemGeneric");
592*89c4ff92SAndroid Build Coastguard Worker     count = SubStringCounter(dump, "CopyMemGeneric");
593*89c4ff92SAndroid Build Coastguard Worker     CHECK(count == 1);
594*89c4ff92SAndroid Build Coastguard Worker 
595*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
596*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutput.begin(), expectedOutput.end()));
597*89c4ff92SAndroid Build Coastguard Worker }
598*89c4ff92SAndroid Build Coastguard Worker 
ImportAndExportWorkload(std::vector<BackendId> backends)599*89c4ff92SAndroid Build Coastguard Worker inline void ImportAndExportWorkload(std::vector<BackendId> backends)
600*89c4ff92SAndroid Build Coastguard Worker {
601*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
602*89c4ff92SAndroid Build Coastguard Worker 
603*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
604*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
605*89c4ff92SAndroid Build Coastguard Worker 
606*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
607*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
608*89c4ff92SAndroid Build Coastguard Worker 
609*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
610*89c4ff92SAndroid Build Coastguard Worker 
611*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
612*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
613*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddActivationLayer(descriptor);
614*89c4ff92SAndroid Build Coastguard Worker 
615*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
616*89c4ff92SAndroid Build Coastguard Worker 
617*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
618*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
619*89c4ff92SAndroid Build Coastguard Worker 
620*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
621*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
622*89c4ff92SAndroid Build Coastguard Worker 
623*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optimizedOptions;
624*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetImportEnabled(true);
625*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetExportEnabled(true);
626*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optimizedOptions);
627*89c4ff92SAndroid Build Coastguard Worker 
628*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
629*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
630*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
631*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
632*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
633*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
634*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
635*89c4ff92SAndroid Build Coastguard Worker 
636*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
637*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
638*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
639*89c4ff92SAndroid Build Coastguard Worker     {
640*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
641*89c4ff92SAndroid Build Coastguard Worker     };
642*89c4ff92SAndroid Build Coastguard Worker 
643*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
644*89c4ff92SAndroid Build Coastguard Worker 
645*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
646*89c4ff92SAndroid Build Coastguard Worker     {
647*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
648*89c4ff92SAndroid Build Coastguard Worker     };
649*89c4ff92SAndroid Build Coastguard Worker 
650*89c4ff92SAndroid Build Coastguard Worker     INFO("Create inference");
651*89c4ff92SAndroid Build Coastguard Worker 
652*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
653*89c4ff92SAndroid Build Coastguard Worker     {
654*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
655*89c4ff92SAndroid Build Coastguard Worker     };
656*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
657*89c4ff92SAndroid Build Coastguard Worker     {
658*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
659*89c4ff92SAndroid Build Coastguard Worker     };
660*89c4ff92SAndroid Build Coastguard Worker 
661*89c4ff92SAndroid Build Coastguard Worker     INFO("Get Profiler");
662*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
663*89c4ff92SAndroid Build Coastguard Worker 
664*89c4ff92SAndroid Build Coastguard Worker     INFO("Run Inference");
665*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
666*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
667*89c4ff92SAndroid Build Coastguard Worker 
668*89c4ff92SAndroid Build Coastguard Worker     INFO("Print Profiler");
669*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
670*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
671*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
672*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
673*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
674*89c4ff92SAndroid Build Coastguard Worker 
675*89c4ff92SAndroid Build Coastguard Worker     // Check there is a SyncMemGeneric workload as we exported
676*89c4ff92SAndroid Build Coastguard Worker     INFO("Find SyncMemGeneric");
677*89c4ff92SAndroid Build Coastguard Worker     int count = SubStringCounter(dump, "SyncMemGeneric");
678*89c4ff92SAndroid Build Coastguard Worker     CHECK(count == 1);
679*89c4ff92SAndroid Build Coastguard Worker 
680*89c4ff92SAndroid Build Coastguard Worker     // Shouldn't be any CopyMemGeneric workloads
681*89c4ff92SAndroid Build Coastguard Worker     INFO("Find CopyMemGeneric");
682*89c4ff92SAndroid Build Coastguard Worker     count = SubStringCounter(dump, "CopyMemGeneric");
683*89c4ff92SAndroid Build Coastguard Worker     CHECK(count == 0);
684*89c4ff92SAndroid Build Coastguard Worker 
685*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
686*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutput.begin(), expectedOutput.end()));
687*89c4ff92SAndroid Build Coastguard Worker }
688*89c4ff92SAndroid Build Coastguard Worker 
ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<BackendId> backends)689*89c4ff92SAndroid Build Coastguard Worker inline void ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<BackendId> backends)
690*89c4ff92SAndroid Build Coastguard Worker {
691*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
692*89c4ff92SAndroid Build Coastguard Worker 
693*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
694*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
695*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(armnn::IRuntime::Create(options));
696*89c4ff92SAndroid Build Coastguard Worker 
697*89c4ff92SAndroid Build Coastguard Worker     // build up the structure of the network
698*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
699*89c4ff92SAndroid Build Coastguard Worker 
700*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
701*89c4ff92SAndroid Build Coastguard Worker 
702*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
703*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
704*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activation = net->AddActivationLayer(descriptor);
705*89c4ff92SAndroid Build Coastguard Worker 
706*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output0 = net->AddOutputLayer(0);
707*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output1 = net->AddOutputLayer(1);
708*89c4ff92SAndroid Build Coastguard Worker 
709*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(activation->GetInputSlot(0));
710*89c4ff92SAndroid Build Coastguard Worker     activation->GetOutputSlot(0).Connect(output0->GetInputSlot(0));
711*89c4ff92SAndroid Build Coastguard Worker     activation->GetOutputSlot(0).Connect(output1->GetInputSlot(0));
712*89c4ff92SAndroid Build Coastguard Worker 
713*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 1 }, DataType::Float32, 0.0f, 0, true));
714*89c4ff92SAndroid Build Coastguard Worker     activation->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 1 }, DataType::Float32));
715*89c4ff92SAndroid Build Coastguard Worker 
716*89c4ff92SAndroid Build Coastguard Worker     // Optimize the network
717*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optimizedOptions;
718*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetImportEnabled(true);
719*89c4ff92SAndroid Build Coastguard Worker     optimizedOptions.SetExportEnabled(true);
720*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optimizedOptions);
721*89c4ff92SAndroid Build Coastguard Worker 
722*89c4ff92SAndroid Build Coastguard Worker     // Loads it into the runtime.
723*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
724*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
725*89c4ff92SAndroid Build Coastguard Worker     // Enable Importing
726*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
727*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
728*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
729*89c4ff92SAndroid Build Coastguard Worker 
730*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
731*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
732*89c4ff92SAndroid Build Coastguard Worker     {
733*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
734*89c4ff92SAndroid Build Coastguard Worker     };
735*89c4ff92SAndroid Build Coastguard Worker 
736*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData0(4);
737*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData1(4);
738*89c4ff92SAndroid Build Coastguard Worker 
739*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
740*89c4ff92SAndroid Build Coastguard Worker     {
741*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
742*89c4ff92SAndroid Build Coastguard Worker     };
743*89c4ff92SAndroid Build Coastguard Worker 
744*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
745*89c4ff92SAndroid Build Coastguard Worker     {
746*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
747*89c4ff92SAndroid Build Coastguard Worker     };
748*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
749*89c4ff92SAndroid Build Coastguard Worker     {
750*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData0.data())},
751*89c4ff92SAndroid Build Coastguard Worker         {1,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 1), outputData1.data())}
752*89c4ff92SAndroid Build Coastguard Worker     };
753*89c4ff92SAndroid Build Coastguard Worker 
754*89c4ff92SAndroid Build Coastguard Worker     // The result of the inference is not important, just the fact that there
755*89c4ff92SAndroid Build Coastguard Worker     // should not be CopyMemGeneric workloads.
756*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
757*89c4ff92SAndroid Build Coastguard Worker 
758*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
759*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
760*89c4ff92SAndroid Build Coastguard Worker 
761*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
762*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
763*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
764*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
765*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
766*89c4ff92SAndroid Build Coastguard Worker 
767*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = std::string::npos;
768*89c4ff92SAndroid Build Coastguard Worker 
769*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] == Compute::CpuRef)
770*89c4ff92SAndroid Build Coastguard Worker     {
771*89c4ff92SAndroid Build Coastguard Worker         found = dump.find("RefActivationWorkload");
772*89c4ff92SAndroid Build Coastguard Worker     }
773*89c4ff92SAndroid Build Coastguard Worker     else if (backends[0] == Compute::CpuAcc)
774*89c4ff92SAndroid Build Coastguard Worker     {
775*89c4ff92SAndroid Build Coastguard Worker         found = dump.find("NeonActivationWorkload");
776*89c4ff92SAndroid Build Coastguard Worker     }
777*89c4ff92SAndroid Build Coastguard Worker     else if (backends[0] == Compute::GpuAcc)
778*89c4ff92SAndroid Build Coastguard Worker     {
779*89c4ff92SAndroid Build Coastguard Worker         found = dump.find("ClActivationWorkload");
780*89c4ff92SAndroid Build Coastguard Worker     }
781*89c4ff92SAndroid Build Coastguard Worker 
782*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
783*89c4ff92SAndroid Build Coastguard Worker     // No contains SyncMemGeneric
784*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("SyncMemGeneric");
785*89c4ff92SAndroid Build Coastguard Worker     CHECK(found == std::string::npos);
786*89c4ff92SAndroid Build Coastguard Worker     // Contains CopyMemGeneric
787*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
788*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
789*89c4ff92SAndroid Build Coastguard Worker 
790*89c4ff92SAndroid Build Coastguard Worker     // Check that the outputs are correct
791*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData0.begin(), outputData0.end(),
792*89c4ff92SAndroid Build Coastguard Worker                                   expectedOutput.begin(), expectedOutput.end()));
793*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData1.begin(), outputData1.end(),
794*89c4ff92SAndroid Build Coastguard Worker                                   expectedOutput.begin(), expectedOutput.end()));
795*89c4ff92SAndroid Build Coastguard Worker }
796*89c4ff92SAndroid Build Coastguard Worker 
StridedSliceInvalidSliceEndToEndTest(std::vector<BackendId> backends)797*89c4ff92SAndroid Build Coastguard Worker inline void StridedSliceInvalidSliceEndToEndTest(std::vector<BackendId> backends)
798*89c4ff92SAndroid Build Coastguard Worker {
799*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
800*89c4ff92SAndroid Build Coastguard Worker 
801*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
802*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
803*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(armnn::IRuntime::Create(options));
804*89c4ff92SAndroid Build Coastguard Worker 
805*89c4ff92SAndroid Build Coastguard Worker     // build up the structure of the network
806*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
807*89c4ff92SAndroid Build Coastguard Worker 
808*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
809*89c4ff92SAndroid Build Coastguard Worker 
810*89c4ff92SAndroid Build Coastguard Worker     // Configure a strided slice with a stride the same size as the input but with a ShrinkAxisMask on the first
811*89c4ff92SAndroid Build Coastguard Worker     // dim of the output to make it too small to hold the specified slice.
812*89c4ff92SAndroid Build Coastguard Worker     StridedSliceDescriptor descriptor;
813*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Begin          = {0, 0};
814*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_End            = {2, 3};
815*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Stride         = {1, 1};
816*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BeginMask      = 0;
817*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_EndMask        = 0;
818*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ShrinkAxisMask = 1;
819*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* stridedSlice = net->AddStridedSliceLayer(descriptor);
820*89c4ff92SAndroid Build Coastguard Worker 
821*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output0 = net->AddOutputLayer(0);
822*89c4ff92SAndroid Build Coastguard Worker 
823*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(stridedSlice->GetInputSlot(0));
824*89c4ff92SAndroid Build Coastguard Worker     stridedSlice->GetOutputSlot(0).Connect(output0->GetInputSlot(0));
825*89c4ff92SAndroid Build Coastguard Worker 
826*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 2, 3 }, DataType::Float32, 0.0f, 0, true));
827*89c4ff92SAndroid Build Coastguard Worker     stridedSlice->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 3 }, DataType::Float32));
828*89c4ff92SAndroid Build Coastguard Worker 
829*89c4ff92SAndroid Build Coastguard Worker     // Attempt to optimize the network and check that the correct exception is thrown
830*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Optimize(*net, backends, runtime->GetDeviceSpec()), armnn::LayerValidationException);
831*89c4ff92SAndroid Build Coastguard Worker }
832*89c4ff92SAndroid Build Coastguard Worker 
ForceImportWithAlignedBuffersEndToEndTest(std::vector<BackendId> backends)833*89c4ff92SAndroid Build Coastguard Worker inline void ForceImportWithAlignedBuffersEndToEndTest(std::vector<BackendId> backends)
834*89c4ff92SAndroid Build Coastguard Worker {
835*89c4ff92SAndroid Build Coastguard Worker     /**
836*89c4ff92SAndroid Build Coastguard Worker      * This test is similar to the Import tests above, we create a network with a square function and pass in a vector
837*89c4ff92SAndroid Build Coastguard Worker      * with 4 floats, square them. and validate the output. We then check the profiling logs to see if input/output
838*89c4ff92SAndroid Build Coastguard Worker      * tensors are copied (CopyMemGeneric) or imported (SyncMemGeneric)
839*89c4ff92SAndroid Build Coastguard Worker      * In this case all inputs and outputs should be imported
840*89c4ff92SAndroid Build Coastguard Worker      */
841*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
842*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
843*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
844*89c4ff92SAndroid Build Coastguard Worker 
845*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
846*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
847*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
848*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
849*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
850*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activationLayer = net->AddActivationLayer(descriptor);
851*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
852*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
853*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
854*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
855*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
856*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
857*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
858*89c4ff92SAndroid Build Coastguard Worker 
859*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
860*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
861*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
862*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined);
863*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
864*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
865*89c4ff92SAndroid Build Coastguard Worker 
866*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
867*89c4ff92SAndroid Build Coastguard Worker 
868*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
869*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
870*89c4ff92SAndroid Build Coastguard Worker     {
871*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
872*89c4ff92SAndroid Build Coastguard Worker     };
873*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
874*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
875*89c4ff92SAndroid Build Coastguard Worker     {
876*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
877*89c4ff92SAndroid Build Coastguard Worker     };
878*89c4ff92SAndroid Build Coastguard Worker 
879*89c4ff92SAndroid Build Coastguard Worker     // Check our input and output pointers are actually aligned
880*89c4ff92SAndroid Build Coastguard Worker     uintptr_t alignment = GetDataTypeSize(DataType::Float32);
881*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(inputData.data()) % alignment));
882*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(outputData.data()) % alignment));
883*89c4ff92SAndroid Build Coastguard Worker 
884*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
885*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
886*89c4ff92SAndroid Build Coastguard Worker     {
887*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
888*89c4ff92SAndroid Build Coastguard Worker     };
889*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
890*89c4ff92SAndroid Build Coastguard Worker     {
891*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
892*89c4ff92SAndroid Build Coastguard Worker     };
893*89c4ff92SAndroid Build Coastguard Worker 
894*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
895*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedInputId> importedInputIds =
896*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportInputs(netId, inputTensors, MemorySource::Malloc);
897*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 1);
898*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedOutputId> importedOutputIds =
899*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportOutputs(netId, outputTensors, MemorySource::Malloc);
900*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 1);
901*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and force the import as the memory is aligned.
902*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, InputTensors(), OutputTensors(), importedInputIds, importedOutputIds);
903*89c4ff92SAndroid Build Coastguard Worker 
904*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
905*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
906*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
907*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
908*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
909*89c4ff92SAndroid Build Coastguard Worker 
910*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] == Compute::CpuAcc)
911*89c4ff92SAndroid Build Coastguard Worker     {
912*89c4ff92SAndroid Build Coastguard Worker         // Reconfigure has not been implemented for CpuAcc so it will always copy, this will break whenever
913*89c4ff92SAndroid Build Coastguard Worker         // reconfigure is implemented
914*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
915*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
916*89c4ff92SAndroid Build Coastguard Worker         // Should be 2 CopyMemGeneric workloads
917*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
918*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 2);
919*89c4ff92SAndroid Build Coastguard Worker     }
920*89c4ff92SAndroid Build Coastguard Worker     else
921*89c4ff92SAndroid Build Coastguard Worker     {
922*89c4ff92SAndroid Build Coastguard Worker         // Check there is a SyncMemGeneric workload as we exported
923*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
924*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 1);
925*89c4ff92SAndroid Build Coastguard Worker         // Shouldn't be any CopyMemGeneric workloads
926*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
927*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
928*89c4ff92SAndroid Build Coastguard Worker     }
929*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
930*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutput.begin(), expectedOutput.end()));
931*89c4ff92SAndroid Build Coastguard Worker }
932*89c4ff92SAndroid Build Coastguard Worker 
ForceImportWithMisalignedInputBuffersEndToEndTest(std::vector<BackendId> backends)933*89c4ff92SAndroid Build Coastguard Worker inline void ForceImportWithMisalignedInputBuffersEndToEndTest(std::vector<BackendId> backends)
934*89c4ff92SAndroid Build Coastguard Worker {
935*89c4ff92SAndroid Build Coastguard Worker     /**
936*89c4ff92SAndroid Build Coastguard Worker      * This test is similar to the Import tests above, we create a network with a square function and pass in a vector
937*89c4ff92SAndroid Build Coastguard Worker      * with 4 floats, square them. and validate the output. We then check the profiling logs to see if input/output
938*89c4ff92SAndroid Build Coastguard Worker      * tensors are copied (CopyMemGeneric) or imported (SyncMemGeneric)
939*89c4ff92SAndroid Build Coastguard Worker      * In this case all only the output should be imported
940*89c4ff92SAndroid Build Coastguard Worker      */
941*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
942*89c4ff92SAndroid Build Coastguard Worker 
943*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
944*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
945*89c4ff92SAndroid Build Coastguard Worker 
946*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
947*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
948*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
949*89c4ff92SAndroid Build Coastguard Worker 
950*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
951*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
952*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activationLayer = net->AddActivationLayer(descriptor);
953*89c4ff92SAndroid Build Coastguard Worker 
954*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
955*89c4ff92SAndroid Build Coastguard Worker 
956*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
957*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
958*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
959*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
960*89c4ff92SAndroid Build Coastguard Worker 
961*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
962*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
963*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
964*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
965*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
966*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined);
967*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
968*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
969*89c4ff92SAndroid Build Coastguard Worker 
970*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
971*89c4ff92SAndroid Build Coastguard Worker 
972*89c4ff92SAndroid Build Coastguard Worker     // This code looks a little funky but the idea is to create a buffer of floats but offset by the size of a char
973*89c4ff92SAndroid Build Coastguard Worker     // this will guarantee that the resultant buffer is misaligned and thus should always be copied.
974*89c4ff92SAndroid Build Coastguard Worker     auto memPtr = std::malloc(4 * sizeof(float) + sizeof(char));
975*89c4ff92SAndroid Build Coastguard Worker 
976*89c4ff92SAndroid Build Coastguard Worker     float* misalignedMemPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(memPtr) + 1);
977*89c4ff92SAndroid Build Coastguard Worker 
978*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
979*89c4ff92SAndroid Build Coastguard Worker     uintptr_t alignment = GetDataTypeSize(DataType::Float32);
980*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedMemPtr) % alignment);
981*89c4ff92SAndroid Build Coastguard Worker 
982*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
983*89c4ff92SAndroid Build Coastguard Worker     {
984*89c4ff92SAndroid Build Coastguard Worker          1.0f, 2.0f, 3.0f, 4.0f
985*89c4ff92SAndroid Build Coastguard Worker     };
986*89c4ff92SAndroid Build Coastguard Worker 
987*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(misalignedMemPtr, inputData.data(), 4*sizeof(float));
988*89c4ff92SAndroid Build Coastguard Worker 
989*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
990*89c4ff92SAndroid Build Coastguard Worker     // Check our output buffer is aligned
991*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(outputData.data()) % alignment));
992*89c4ff92SAndroid Build Coastguard Worker 
993*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
994*89c4ff92SAndroid Build Coastguard Worker     {
995*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
996*89c4ff92SAndroid Build Coastguard Worker     };
997*89c4ff92SAndroid Build Coastguard Worker 
998*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
999*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
1000*89c4ff92SAndroid Build Coastguard Worker     {
1001*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), misalignedMemPtr)},
1002*89c4ff92SAndroid Build Coastguard Worker     };
1003*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
1004*89c4ff92SAndroid Build Coastguard Worker     {
1005*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
1006*89c4ff92SAndroid Build Coastguard Worker     };
1007*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
1008*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedInputId> importedInputIds =
1009*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportInputs(netId, inputTensors, MemorySource::Malloc);
1010*89c4ff92SAndroid Build Coastguard Worker     // We expect the import to have failed.
1011*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 0);
1012*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedOutputId> importedOutputIds =
1013*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportOutputs(netId, outputTensors, MemorySource::Malloc);
1014*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 1);
1015*89c4ff92SAndroid Build Coastguard Worker 
1016*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and force the import as the memory is misaligned.
1017*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, OutputTensors(), importedInputIds, importedOutputIds);
1018*89c4ff92SAndroid Build Coastguard Worker 
1019*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
1020*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
1021*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
1022*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
1023*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
1024*89c4ff92SAndroid Build Coastguard Worker 
1025*89c4ff92SAndroid Build Coastguard Worker     // GpuAcc is a different case to CpuRef and CpuAcc, it doesn't use the buffer directly but instead maps it to a
1026*89c4ff92SAndroid Build Coastguard Worker     // new set of addresses within Gpu Memory. This will almost always be auto-aligned, so we don't need to check
1027*89c4ff92SAndroid Build Coastguard Worker     // for imports/copies. Only that the output is correct.
1028*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] != Compute::GpuAcc)
1029*89c4ff92SAndroid Build Coastguard Worker     {
1030*89c4ff92SAndroid Build Coastguard Worker         if (backends[0] == Compute::CpuAcc)
1031*89c4ff92SAndroid Build Coastguard Worker         {
1032*89c4ff92SAndroid Build Coastguard Worker             // Reconfigure has not been implemented for CpuAcc so it will always copy, this will break whenever
1033*89c4ff92SAndroid Build Coastguard Worker             // reconfigure is implemented
1034*89c4ff92SAndroid Build Coastguard Worker             // We should get 0 SyncMemGeneric for the Output
1035*89c4ff92SAndroid Build Coastguard Worker             int count = SubStringCounter(dump, "SyncMemGeneric");
1036*89c4ff92SAndroid Build Coastguard Worker             CHECK(count == 0);
1037*89c4ff92SAndroid Build Coastguard Worker             // Should be 2 CopyMemGeneric as we copied the input
1038*89c4ff92SAndroid Build Coastguard Worker             count = SubStringCounter(dump, "CopyMemGeneric");
1039*89c4ff92SAndroid Build Coastguard Worker             CHECK(count == 2);
1040*89c4ff92SAndroid Build Coastguard Worker         }
1041*89c4ff92SAndroid Build Coastguard Worker         else
1042*89c4ff92SAndroid Build Coastguard Worker         {
1043*89c4ff92SAndroid Build Coastguard Worker             // We should get 1 SyncMemGeneric for the Output
1044*89c4ff92SAndroid Build Coastguard Worker             int count = SubStringCounter(dump, "SyncMemGeneric");
1045*89c4ff92SAndroid Build Coastguard Worker             CHECK(count == 1);
1046*89c4ff92SAndroid Build Coastguard Worker             // Should only be 1 CopyMemGeneric as we copied the input
1047*89c4ff92SAndroid Build Coastguard Worker             count = SubStringCounter(dump, "CopyMemGeneric");
1048*89c4ff92SAndroid Build Coastguard Worker             CHECK(count == 1);
1049*89c4ff92SAndroid Build Coastguard Worker         }
1050*89c4ff92SAndroid Build Coastguard Worker     }
1051*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
1052*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutput.begin(), expectedOutput.end()));
1053*89c4ff92SAndroid Build Coastguard Worker     std::free(memPtr);
1054*89c4ff92SAndroid Build Coastguard Worker }
1055*89c4ff92SAndroid Build Coastguard Worker 
ForceImportWithMisalignedOutputBuffersEndToEndTest(std::vector<BackendId> backends)1056*89c4ff92SAndroid Build Coastguard Worker inline void ForceImportWithMisalignedOutputBuffersEndToEndTest(std::vector<BackendId> backends)
1057*89c4ff92SAndroid Build Coastguard Worker {
1058*89c4ff92SAndroid Build Coastguard Worker     /**
1059*89c4ff92SAndroid Build Coastguard Worker      * This test is similar to the Import tests above, we create a network with a square function and pass in a vector
1060*89c4ff92SAndroid Build Coastguard Worker      * with 4 floats, square them. and validate the output. We then check the profiling logs to see if input/output
1061*89c4ff92SAndroid Build Coastguard Worker      * tensors are copied (CopyMemGeneric) or imported (SyncMemGeneric)
1062*89c4ff92SAndroid Build Coastguard Worker      * In this case only the input should be imported
1063*89c4ff92SAndroid Build Coastguard Worker      */
1064*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
1065*89c4ff92SAndroid Build Coastguard Worker 
1066*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
1067*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
1068*89c4ff92SAndroid Build Coastguard Worker 
1069*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
1070*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
1071*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
1072*89c4ff92SAndroid Build Coastguard Worker 
1073*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
1074*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
1075*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activationLayer = net->AddActivationLayer(descriptor);
1076*89c4ff92SAndroid Build Coastguard Worker 
1077*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
1078*89c4ff92SAndroid Build Coastguard Worker 
1079*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
1080*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
1081*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
1082*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
1083*89c4ff92SAndroid Build Coastguard Worker 
1084*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
1085*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
1086*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
1087*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
1088*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
1089*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined);
1090*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
1091*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
1092*89c4ff92SAndroid Build Coastguard Worker 
1093*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
1094*89c4ff92SAndroid Build Coastguard Worker 
1095*89c4ff92SAndroid Build Coastguard Worker     // This code looks a little funky but the idea is to create a buffer of floats but offset by the size of a char
1096*89c4ff92SAndroid Build Coastguard Worker     // this will guarantee that the resultant buffer is misaligned and thus should always be copied.
1097*89c4ff92SAndroid Build Coastguard Worker     auto memPtr = std::malloc(4 * sizeof(float) + sizeof(char));
1098*89c4ff92SAndroid Build Coastguard Worker 
1099*89c4ff92SAndroid Build Coastguard Worker     float* misalignedMemPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(memPtr) + 1);
1100*89c4ff92SAndroid Build Coastguard Worker 
1101*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
1102*89c4ff92SAndroid Build Coastguard Worker     uintptr_t alignment = GetDataTypeSize(DataType::Float32);
1103*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedMemPtr) % alignment);
1104*89c4ff92SAndroid Build Coastguard Worker 
1105*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
1106*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
1107*89c4ff92SAndroid Build Coastguard Worker     {
1108*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
1109*89c4ff92SAndroid Build Coastguard Worker     };
1110*89c4ff92SAndroid Build Coastguard Worker 
1111*89c4ff92SAndroid Build Coastguard Worker     // Check our input buffer is aligned
1112*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(inputData.data()) % alignment));
1113*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
1114*89c4ff92SAndroid Build Coastguard Worker     {
1115*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
1116*89c4ff92SAndroid Build Coastguard Worker     };
1117*89c4ff92SAndroid Build Coastguard Worker 
1118*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
1119*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
1120*89c4ff92SAndroid Build Coastguard Worker     {
1121*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
1122*89c4ff92SAndroid Build Coastguard Worker     };
1123*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
1124*89c4ff92SAndroid Build Coastguard Worker     {
1125*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), misalignedMemPtr)}
1126*89c4ff92SAndroid Build Coastguard Worker     };
1127*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
1128*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedInputId> importedInputIds =
1129*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportInputs(netId, inputTensors, MemorySource::Malloc);
1130*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 1);
1131*89c4ff92SAndroid Build Coastguard Worker     // We expect this to fail.
1132*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedOutputId> importedOutputIds =
1133*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportOutputs(netId, outputTensors, MemorySource::Malloc);
1134*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 0);
1135*89c4ff92SAndroid Build Coastguard Worker 
1136*89c4ff92SAndroid Build Coastguard Worker     // Even if importing the output failed we still expect to be able to get it to work.
1137*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, InputTensors(), outputTensors, importedInputIds, importedOutputIds);
1138*89c4ff92SAndroid Build Coastguard Worker 
1139*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
1140*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
1141*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
1142*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
1143*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
1144*89c4ff92SAndroid Build Coastguard Worker 
1145*89c4ff92SAndroid Build Coastguard Worker     // GpuAcc is a different case to CpuRef and CpuAcc, it doesn't use the buffer directly but instead maps it to a
1146*89c4ff92SAndroid Build Coastguard Worker     // new set of addresses within Gpu Memory. This will almost always be auto-aligned, so we don't need to check
1147*89c4ff92SAndroid Build Coastguard Worker     // for imports/copies. Only that the output is correct.
1148*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] != Compute::GpuAcc)
1149*89c4ff92SAndroid Build Coastguard Worker     {
1150*89c4ff92SAndroid Build Coastguard Worker         // Even though we Imported the Input we still shouldn't have a SyncMemGeneric
1151*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1152*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
1153*89c4ff92SAndroid Build Coastguard Worker         // Should only be 1 CopyMemGeneric as we copied the input
1154*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1155*89c4ff92SAndroid Build Coastguard Worker         if (backends[0] == Compute::CpuAcc)
1156*89c4ff92SAndroid Build Coastguard Worker         {
1157*89c4ff92SAndroid Build Coastguard Worker             // Reconfigure has not been implemented for CpuAcc so it will always copy, this will break whenever
1158*89c4ff92SAndroid Build Coastguard Worker             // reconfigure is implemented
1159*89c4ff92SAndroid Build Coastguard Worker             CHECK(count == 2);
1160*89c4ff92SAndroid Build Coastguard Worker         }
1161*89c4ff92SAndroid Build Coastguard Worker         else
1162*89c4ff92SAndroid Build Coastguard Worker         {
1163*89c4ff92SAndroid Build Coastguard Worker             CHECK(count == 1);
1164*89c4ff92SAndroid Build Coastguard Worker         }
1165*89c4ff92SAndroid Build Coastguard Worker         // Check the output is correct
1166*89c4ff92SAndroid Build Coastguard Worker     }
1167*89c4ff92SAndroid Build Coastguard Worker     unsigned int index = 0;
1168*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(expectedOutput.size(), 0);
1169*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(outputData.data(), misalignedMemPtr, expectedOutput.size() * sizeof(float));
1170*89c4ff92SAndroid Build Coastguard Worker     for (auto outputValue : expectedOutput)
1171*89c4ff92SAndroid Build Coastguard Worker     {
1172*89c4ff92SAndroid Build Coastguard Worker         CHECK(outputValue == outputData[index]);
1173*89c4ff92SAndroid Build Coastguard Worker         ++index;
1174*89c4ff92SAndroid Build Coastguard Worker     }
1175*89c4ff92SAndroid Build Coastguard Worker     std::free(memPtr);
1176*89c4ff92SAndroid Build Coastguard Worker }
1177*89c4ff92SAndroid Build Coastguard Worker 
ForceImportWithMisalignedInputAndOutputBuffersEndToEndTest(std::vector<BackendId> backends)1178*89c4ff92SAndroid Build Coastguard Worker inline void ForceImportWithMisalignedInputAndOutputBuffersEndToEndTest(std::vector<BackendId> backends)
1179*89c4ff92SAndroid Build Coastguard Worker {
1180*89c4ff92SAndroid Build Coastguard Worker     /**
1181*89c4ff92SAndroid Build Coastguard Worker      * This test is similar to the Import tests above, we create a network with a square function and pass in a vector
1182*89c4ff92SAndroid Build Coastguard Worker      * with 4 floats, square them. and validate the output. We then check the profiling logs to see if input/output
1183*89c4ff92SAndroid Build Coastguard Worker      * tensors are copied (CopyMemGeneric) or imported (SyncMemGeneric)
1184*89c4ff92SAndroid Build Coastguard Worker      * In this case all inputs and outputs should be copied
1185*89c4ff92SAndroid Build Coastguard Worker      */
1186*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
1187*89c4ff92SAndroid Build Coastguard Worker 
1188*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
1189*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
1190*89c4ff92SAndroid Build Coastguard Worker 
1191*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
1192*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
1193*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
1194*89c4ff92SAndroid Build Coastguard Worker 
1195*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
1196*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
1197*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activationLayer = net->AddActivationLayer(descriptor);
1198*89c4ff92SAndroid Build Coastguard Worker 
1199*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
1200*89c4ff92SAndroid Build Coastguard Worker 
1201*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
1202*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
1203*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
1204*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
1205*89c4ff92SAndroid Build Coastguard Worker 
1206*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
1207*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
1208*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
1209*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
1210*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
1211*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined);
1212*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
1213*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
1214*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
1215*89c4ff92SAndroid Build Coastguard Worker 
1216*89c4ff92SAndroid Build Coastguard Worker     // This code looks a little funky but the idea is to create a buffer of floats but offset by the size of a char
1217*89c4ff92SAndroid Build Coastguard Worker     // this will guarantee that the resultant buffer is misaligned and thus should always be copied.
1218*89c4ff92SAndroid Build Coastguard Worker     auto inputMemPtr = std::malloc(4 * sizeof(float) + sizeof(char));
1219*89c4ff92SAndroid Build Coastguard Worker     float* misalignedInputPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(inputMemPtr) + 1);
1220*89c4ff92SAndroid Build Coastguard Worker 
1221*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
1222*89c4ff92SAndroid Build Coastguard Worker     uintptr_t alignment = GetDataTypeSize(DataType::Float32);
1223*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedInputPtr) % alignment);
1224*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
1225*89c4ff92SAndroid Build Coastguard Worker     {
1226*89c4ff92SAndroid Build Coastguard Worker          1.0f, 2.0f, 3.0f, 4.0f
1227*89c4ff92SAndroid Build Coastguard Worker     };
1228*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(misalignedInputPtr, inputData.data(), 4*sizeof(float));
1229*89c4ff92SAndroid Build Coastguard Worker 
1230*89c4ff92SAndroid Build Coastguard Worker     auto outputMemPtr = std::malloc(4 * sizeof(float) + sizeof(char));
1231*89c4ff92SAndroid Build Coastguard Worker     float* misalignedOutputPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(outputMemPtr) + 1);
1232*89c4ff92SAndroid Build Coastguard Worker 
1233*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
1234*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedOutputPtr) % alignment);
1235*89c4ff92SAndroid Build Coastguard Worker 
1236*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
1237*89c4ff92SAndroid Build Coastguard Worker     {
1238*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
1239*89c4ff92SAndroid Build Coastguard Worker     };
1240*89c4ff92SAndroid Build Coastguard Worker 
1241*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
1242*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
1243*89c4ff92SAndroid Build Coastguard Worker     {
1244*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), misalignedInputPtr)},
1245*89c4ff92SAndroid Build Coastguard Worker     };
1246*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
1247*89c4ff92SAndroid Build Coastguard Worker     {
1248*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), misalignedOutputPtr)}
1249*89c4ff92SAndroid Build Coastguard Worker     };
1250*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
1251*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedInputId> importedInputIds =
1252*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportInputs(netId, inputTensors, MemorySource::Malloc);
1253*89c4ff92SAndroid Build Coastguard Worker     // Import should have failed.
1254*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 0);
1255*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedOutputId> importedOutputIds =
1256*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportOutputs(netId, outputTensors, MemorySource::Malloc);
1257*89c4ff92SAndroid Build Coastguard Worker     // Import should have failed.
1258*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 0);
1259*89c4ff92SAndroid Build Coastguard Worker 
1260*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and force the import as the memory is misaligned.
1261*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors, importedInputIds, importedOutputIds);
1262*89c4ff92SAndroid Build Coastguard Worker 
1263*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
1264*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
1265*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
1266*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);
1267*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
1268*89c4ff92SAndroid Build Coastguard Worker 
1269*89c4ff92SAndroid Build Coastguard Worker     // GpuAcc is a different case to CpuRef and CpuAcc, it doesn't use the buffer directly but instead maps it to a
1270*89c4ff92SAndroid Build Coastguard Worker     // new set of addresses within Gpu Memory. This will almost always be auto-aligned, so we don't need to check
1271*89c4ff92SAndroid Build Coastguard Worker     // for imports/copies. Only that the output is correct.
1272*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] != Compute::GpuAcc)
1273*89c4ff92SAndroid Build Coastguard Worker     {
1274*89c4ff92SAndroid Build Coastguard Worker         // We can only copy so there should be no SyncMemGeneric
1275*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1276*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
1277*89c4ff92SAndroid Build Coastguard Worker         // Should only be CopyMemGeneric workloads as we copied all buffers
1278*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1279*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 2);
1280*89c4ff92SAndroid Build Coastguard Worker     }
1281*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
1282*89c4ff92SAndroid Build Coastguard Worker     unsigned int index = 0;
1283*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(expectedOutput.size(), 0);
1284*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(outputData.data(), misalignedOutputPtr, expectedOutput.size() * sizeof(float));
1285*89c4ff92SAndroid Build Coastguard Worker     for (auto expectedValue : expectedOutput)
1286*89c4ff92SAndroid Build Coastguard Worker     {
1287*89c4ff92SAndroid Build Coastguard Worker         CHECK(expectedValue == outputData[index]);
1288*89c4ff92SAndroid Build Coastguard Worker         ++index;
1289*89c4ff92SAndroid Build Coastguard Worker     }
1290*89c4ff92SAndroid Build Coastguard Worker     std::free(inputMemPtr);
1291*89c4ff92SAndroid Build Coastguard Worker     std::free(outputMemPtr);
1292*89c4ff92SAndroid Build Coastguard Worker }
1293*89c4ff92SAndroid Build Coastguard Worker 
ForceImportRepeatedInferencesEndToEndTest(std::vector<BackendId> backends)1294*89c4ff92SAndroid Build Coastguard Worker inline void ForceImportRepeatedInferencesEndToEndTest(std::vector<BackendId> backends)
1295*89c4ff92SAndroid Build Coastguard Worker {
1296*89c4ff92SAndroid Build Coastguard Worker     /**
1297*89c4ff92SAndroid Build Coastguard Worker      * This test is similar to the Import tests above, we create a network with a square function and pass in a vector
1298*89c4ff92SAndroid Build Coastguard Worker      * with 4 floats, square them. and validate the output. We then check the profiling logs to see if input/output
1299*89c4ff92SAndroid Build Coastguard Worker      * tensors are copied (CopyMemGeneric) or imported (SyncMemGeneric)
1300*89c4ff92SAndroid Build Coastguard Worker      * In this we create some aligned buffers, import them into a network and validate the output and number of
1301*89c4ff92SAndroid Build Coastguard Worker      * SynMemGeneric/CopyMemgeneric. Then we try the same network again with misaligned buffers to make sure it falls
1302*89c4ff92SAndroid Build Coastguard Worker      * back to copying correctly.
1303*89c4ff92SAndroid Build Coastguard Worker      */
1304*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
1305*89c4ff92SAndroid Build Coastguard Worker 
1306*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
1307*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
1308*89c4ff92SAndroid Build Coastguard Worker 
1309*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
1310*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
1311*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
1312*89c4ff92SAndroid Build Coastguard Worker 
1313*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
1314*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
1315*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activationLayer = net->AddActivationLayer(descriptor);
1316*89c4ff92SAndroid Build Coastguard Worker 
1317*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
1318*89c4ff92SAndroid Build Coastguard Worker 
1319*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
1320*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
1321*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
1322*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
1323*89c4ff92SAndroid Build Coastguard Worker 
1324*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
1325*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
1326*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
1327*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
1328*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
1329*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined);
1330*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
1331*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
1332*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
1333*89c4ff92SAndroid Build Coastguard Worker 
1334*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
1335*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
1336*89c4ff92SAndroid Build Coastguard Worker     {
1337*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
1338*89c4ff92SAndroid Build Coastguard Worker     };
1339*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
1340*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
1341*89c4ff92SAndroid Build Coastguard Worker     {
1342*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
1343*89c4ff92SAndroid Build Coastguard Worker     };
1344*89c4ff92SAndroid Build Coastguard Worker 
1345*89c4ff92SAndroid Build Coastguard Worker     // Check our input and output pointers are actually aligned
1346*89c4ff92SAndroid Build Coastguard Worker     uintptr_t alignment = GetDataTypeSize(DataType::Float32);
1347*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(inputData.data()) % alignment));
1348*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(outputData.data()) % alignment));
1349*89c4ff92SAndroid Build Coastguard Worker 
1350*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
1351*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
1352*89c4ff92SAndroid Build Coastguard Worker     {
1353*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
1354*89c4ff92SAndroid Build Coastguard Worker     };
1355*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
1356*89c4ff92SAndroid Build Coastguard Worker     {
1357*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
1358*89c4ff92SAndroid Build Coastguard Worker     };
1359*89c4ff92SAndroid Build Coastguard Worker 
1360*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
1361*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedInputId> importedInputIds =
1362*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportInputs(netId, inputTensors, MemorySource::Malloc);
1363*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 1);
1364*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedOutputId> importedOutputIds =
1365*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportOutputs(netId, outputTensors, MemorySource::Malloc);
1366*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 1);
1367*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and force the import as the memory is aligned.
1368*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, InputTensors(), OutputTensors(), importedInputIds, importedOutputIds);
1369*89c4ff92SAndroid Build Coastguard Worker 
1370*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.AnalyzeEventsAndWriteResults() output to get the workload execution
1371*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
1372*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
1373*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->AnalyzeEventsAndWriteResults(ss);
1374*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
1375*89c4ff92SAndroid Build Coastguard Worker 
1376*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] == Compute::CpuAcc)
1377*89c4ff92SAndroid Build Coastguard Worker     {
1378*89c4ff92SAndroid Build Coastguard Worker         // Reconfigure has not been implemented for CpuAcc so it will always copy, this will break whenever
1379*89c4ff92SAndroid Build Coastguard Worker         // reconfigure is implemented
1380*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1381*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
1382*89c4ff92SAndroid Build Coastguard Worker         // Should be 2 CopyMemGeneric workloads
1383*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1384*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1385*89c4ff92SAndroid Build Coastguard Worker     }
1386*89c4ff92SAndroid Build Coastguard Worker     else
1387*89c4ff92SAndroid Build Coastguard Worker     {
1388*89c4ff92SAndroid Build Coastguard Worker         // Check there is at least 1 SyncMemGeneric workload as we exported
1389*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1390*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1391*89c4ff92SAndroid Build Coastguard Worker         // Shouldn't be any CopyMemGeneric workloads
1392*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1393*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
1394*89c4ff92SAndroid Build Coastguard Worker     }
1395*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
1396*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutput.begin(), expectedOutput.end()));
1397*89c4ff92SAndroid Build Coastguard Worker 
1398*89c4ff92SAndroid Build Coastguard Worker     // This code looks a little funky but the idea is to create a buffer of floats but offset by the size of a char
1399*89c4ff92SAndroid Build Coastguard Worker     // this will guarantee that the resultant buffer is misaligned and thus should always be copied.
1400*89c4ff92SAndroid Build Coastguard Worker     auto inputMemPtr = std::malloc(4 * sizeof(float) + sizeof(char));
1401*89c4ff92SAndroid Build Coastguard Worker     float* misalignedInputPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(inputMemPtr) + 1);
1402*89c4ff92SAndroid Build Coastguard Worker 
1403*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
1404*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedInputPtr) % alignment);
1405*89c4ff92SAndroid Build Coastguard Worker 
1406*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValues
1407*89c4ff92SAndroid Build Coastguard Worker     {
1408*89c4ff92SAndroid Build Coastguard Worker          2.0f, 3.0f, 4.0f, 5.0f
1409*89c4ff92SAndroid Build Coastguard Worker     };
1410*89c4ff92SAndroid Build Coastguard Worker 
1411*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(misalignedInputPtr, inputValues.data(), inputValues.size()*sizeof(float));
1412*89c4ff92SAndroid Build Coastguard Worker 
1413*89c4ff92SAndroid Build Coastguard Worker     auto outputMemPtr = std::malloc(4 * sizeof(float) + sizeof(char));
1414*89c4ff92SAndroid Build Coastguard Worker     float* misalignedOutputPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(outputMemPtr) + 1);
1415*89c4ff92SAndroid Build Coastguard Worker 
1416*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
1417*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedOutputPtr) % alignment);
1418*89c4ff92SAndroid Build Coastguard Worker 
1419*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedMisalignedOutput
1420*89c4ff92SAndroid Build Coastguard Worker     {
1421*89c4ff92SAndroid Build Coastguard Worker          4.0f, 9.0f, 16.0f, 25.0f
1422*89c4ff92SAndroid Build Coastguard Worker     };
1423*89c4ff92SAndroid Build Coastguard Worker 
1424*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Second Inference");
1425*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensorsMisaligned
1426*89c4ff92SAndroid Build Coastguard Worker     {
1427*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), misalignedInputPtr)},
1428*89c4ff92SAndroid Build Coastguard Worker     };
1429*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensorsMisaligned
1430*89c4ff92SAndroid Build Coastguard Worker     {
1431*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), misalignedOutputPtr)}
1432*89c4ff92SAndroid Build Coastguard Worker     };
1433*89c4ff92SAndroid Build Coastguard Worker     importedInputIds = runtime->ImportInputs(netId, inputTensorsMisaligned, MemorySource::Malloc);
1434*89c4ff92SAndroid Build Coastguard Worker     // Import should fail.
1435*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 0);
1436*89c4ff92SAndroid Build Coastguard Worker     importedOutputIds = runtime->ImportOutputs(netId, outputTensorsMisaligned, MemorySource::Malloc);
1437*89c4ff92SAndroid Build Coastguard Worker     // Import should fail.
1438*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 0);
1439*89c4ff92SAndroid Build Coastguard Worker 
1440*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and force the import as the memory is misaligned.
1441*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId,
1442*89c4ff92SAndroid Build Coastguard Worker                              inputTensorsMisaligned,
1443*89c4ff92SAndroid Build Coastguard Worker                              outputTensorsMisaligned,
1444*89c4ff92SAndroid Build Coastguard Worker                              importedInputIds,
1445*89c4ff92SAndroid Build Coastguard Worker                              importedOutputIds);
1446*89c4ff92SAndroid Build Coastguard Worker 
1447*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.AnalyzeEventsAndWriteResults() output to get the workload execution
1448*89c4ff92SAndroid Build Coastguard Worker     // We need to use AnalyzeEventsAndWriteResults here to make sure the second inference has been profiled
1449*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->AnalyzeEventsAndWriteResults(ss);
1450*89c4ff92SAndroid Build Coastguard Worker     dump = ss.str();
1451*89c4ff92SAndroid Build Coastguard Worker 
1452*89c4ff92SAndroid Build Coastguard Worker     // GpuAcc is a different case to CpuRef and CpuAcc, it doesn't use the buffer directly but instead maps it to a
1453*89c4ff92SAndroid Build Coastguard Worker     // new set of addresses within Gpu Memory. This will almost always be auto-aligned, so we don't need to check
1454*89c4ff92SAndroid Build Coastguard Worker     // for imports/copies. Only that the output is correct.
1455*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] != Compute::GpuAcc)
1456*89c4ff92SAndroid Build Coastguard Worker     {
1457*89c4ff92SAndroid Build Coastguard Worker         // The SyncMemGeneric will still be in the profiling log from the first inference
1458*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1459*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1460*89c4ff92SAndroid Build Coastguard Worker         // We should now see CopyMemGeneric workloads as we copied all buffers
1461*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1462*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1463*89c4ff92SAndroid Build Coastguard Worker     }
1464*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
1465*89c4ff92SAndroid Build Coastguard Worker     unsigned int index = 0;
1466*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> alignedOutputData(expectedMisalignedOutput.size(), 0);
1467*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(alignedOutputData.data(), misalignedOutputPtr, expectedMisalignedOutput.size() * sizeof(float));
1468*89c4ff92SAndroid Build Coastguard Worker     for (auto outputValue : expectedMisalignedOutput)
1469*89c4ff92SAndroid Build Coastguard Worker     {
1470*89c4ff92SAndroid Build Coastguard Worker         CHECK(outputValue == alignedOutputData[index]);
1471*89c4ff92SAndroid Build Coastguard Worker         ++index;
1472*89c4ff92SAndroid Build Coastguard Worker     }
1473*89c4ff92SAndroid Build Coastguard Worker     // Clean up to avoid interfering with other tests
1474*89c4ff92SAndroid Build Coastguard Worker     runtime->UnloadNetwork(netId);
1475*89c4ff92SAndroid Build Coastguard Worker     std::free(inputMemPtr);
1476*89c4ff92SAndroid Build Coastguard Worker     std::free(outputMemPtr);
1477*89c4ff92SAndroid Build Coastguard Worker }
1478*89c4ff92SAndroid Build Coastguard Worker 
1479*89c4ff92SAndroid Build Coastguard Worker 
ForceImportRepeatedInferencesInvertedEndToEndTest(std::vector<BackendId> backends)1480*89c4ff92SAndroid Build Coastguard Worker inline void ForceImportRepeatedInferencesInvertedEndToEndTest(std::vector<BackendId> backends)
1481*89c4ff92SAndroid Build Coastguard Worker {
1482*89c4ff92SAndroid Build Coastguard Worker     /**
1483*89c4ff92SAndroid Build Coastguard Worker      * This test is similar to the Import tests above, we create a network with a square function and pass in a vector
1484*89c4ff92SAndroid Build Coastguard Worker      * with 4 floats, square them. and validate the output. We then check the profiling logs to see if input/output
1485*89c4ff92SAndroid Build Coastguard Worker      * tensors are copied (CopyMemGeneric) or imported (SyncMemGeneric)
1486*89c4ff92SAndroid Build Coastguard Worker      * In this we create some misaligned buffers, copy them into a network and validate the output and number of
1487*89c4ff92SAndroid Build Coastguard Worker      * SynMemGeneric/CopyMemgeneric. Then we try the same network again with aligned buffers to make sure it switches
1488*89c4ff92SAndroid Build Coastguard Worker      * to importing correctly.
1489*89c4ff92SAndroid Build Coastguard Worker      */
1490*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
1491*89c4ff92SAndroid Build Coastguard Worker 
1492*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
1493*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
1494*89c4ff92SAndroid Build Coastguard Worker 
1495*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
1496*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
1497*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0);
1498*89c4ff92SAndroid Build Coastguard Worker 
1499*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor descriptor;
1500*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = ActivationFunction::Square;
1501*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activationLayer = net->AddActivationLayer(descriptor);
1502*89c4ff92SAndroid Build Coastguard Worker 
1503*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0);
1504*89c4ff92SAndroid Build Coastguard Worker 
1505*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
1506*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
1507*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32, 0.0f, 0, true));
1508*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 1, 4 }, DataType::Float32));
1509*89c4ff92SAndroid Build Coastguard Worker 
1510*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
1511*89c4ff92SAndroid Build Coastguard Worker     INFO("Load Network");
1512*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
1513*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
1514*89c4ff92SAndroid Build Coastguard Worker     std::string errorMessage;
1515*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined);
1516*89c4ff92SAndroid Build Coastguard Worker     armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage, networkProperties);
1517*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
1518*89c4ff92SAndroid Build Coastguard Worker     INFO("Generate Data");
1519*89c4ff92SAndroid Build Coastguard Worker 
1520*89c4ff92SAndroid Build Coastguard Worker     // This code looks a little funky but the idea is to create a buffer of floats but offset by the size of a char
1521*89c4ff92SAndroid Build Coastguard Worker     // this will guarantee that the resultant buffer is misaligned and thus should always be copied.
1522*89c4ff92SAndroid Build Coastguard Worker     auto inputMemPtr = std::malloc(4 * sizeof(float) + sizeof(char));
1523*89c4ff92SAndroid Build Coastguard Worker     float* misalignedInputPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(inputMemPtr) + 1);
1524*89c4ff92SAndroid Build Coastguard Worker 
1525*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
1526*89c4ff92SAndroid Build Coastguard Worker     uintptr_t alignment = GetDataTypeSize(DataType::Float32);
1527*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedInputPtr) % alignment);
1528*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValues
1529*89c4ff92SAndroid Build Coastguard Worker     {
1530*89c4ff92SAndroid Build Coastguard Worker          2.0f, 3.0f, 4.0f, 5.0f
1531*89c4ff92SAndroid Build Coastguard Worker     };
1532*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(misalignedInputPtr, inputValues.data(), inputValues.size() * sizeof(float));
1533*89c4ff92SAndroid Build Coastguard Worker 
1534*89c4ff92SAndroid Build Coastguard Worker     auto outputMemPtr = std::malloc(4 * sizeof(float) + sizeof(char));
1535*89c4ff92SAndroid Build Coastguard Worker     float* misalignedOutputPtr = reinterpret_cast<float*>(reinterpret_cast<char*>(outputMemPtr) + 1);
1536*89c4ff92SAndroid Build Coastguard Worker 
1537*89c4ff92SAndroid Build Coastguard Worker     // Check if our pointer is truly misaligned
1538*89c4ff92SAndroid Build Coastguard Worker     CHECK (reinterpret_cast<uintptr_t>(misalignedOutputPtr) % alignment);
1539*89c4ff92SAndroid Build Coastguard Worker 
1540*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedMisalignedOutput
1541*89c4ff92SAndroid Build Coastguard Worker     {
1542*89c4ff92SAndroid Build Coastguard Worker          4.0f, 9.0f, 16.0f, 25.0f
1543*89c4ff92SAndroid Build Coastguard Worker     };
1544*89c4ff92SAndroid Build Coastguard Worker 
1545*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Second Inference");
1546*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensorsMisaligned
1547*89c4ff92SAndroid Build Coastguard Worker     {
1548*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), misalignedInputPtr)},
1549*89c4ff92SAndroid Build Coastguard Worker     };
1550*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensorsMisaligned
1551*89c4ff92SAndroid Build Coastguard Worker     {
1552*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), misalignedOutputPtr)}
1553*89c4ff92SAndroid Build Coastguard Worker     };
1554*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
1555*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedInputId>  importedInputIds =
1556*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportInputs(netId, inputTensorsMisaligned, MemorySource::Malloc);
1557*89c4ff92SAndroid Build Coastguard Worker     // Import should fail.
1558*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 0);
1559*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedOutputId> importedOutputIds =
1560*89c4ff92SAndroid Build Coastguard Worker         runtime->ImportOutputs(netId, outputTensorsMisaligned, MemorySource::Malloc);
1561*89c4ff92SAndroid Build Coastguard Worker     // Import should fail.
1562*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 0);
1563*89c4ff92SAndroid Build Coastguard Worker 
1564*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and force the import as the memory is misaligned.
1565*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId,
1566*89c4ff92SAndroid Build Coastguard Worker                              inputTensorsMisaligned,
1567*89c4ff92SAndroid Build Coastguard Worker                              outputTensorsMisaligned,
1568*89c4ff92SAndroid Build Coastguard Worker                              importedInputIds,
1569*89c4ff92SAndroid Build Coastguard Worker                              importedOutputIds);
1570*89c4ff92SAndroid Build Coastguard Worker 
1571*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.AnalyzeEventsAndWriteResults() output to get the workload execution
1572*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
1573*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
1574*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->AnalyzeEventsAndWriteResults(ss);
1575*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
1576*89c4ff92SAndroid Build Coastguard Worker 
1577*89c4ff92SAndroid Build Coastguard Worker     // GpuAcc is a different case to CpuRef and CpuAcc, it doesn't use the buffer directly but instead maps it to a
1578*89c4ff92SAndroid Build Coastguard Worker     // new set of addresses within Gpu Memory. This will almost always be auto-aligned, so we don't need to check
1579*89c4ff92SAndroid Build Coastguard Worker     // for imports/copies. Only that the output is correct.
1580*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] != Compute::GpuAcc)
1581*89c4ff92SAndroid Build Coastguard Worker     {
1582*89c4ff92SAndroid Build Coastguard Worker         // We can only copy so there should be no SyncMemGeneric
1583*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1584*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
1585*89c4ff92SAndroid Build Coastguard Worker         // Should only be CopyMemGeneric workloads as we copied all buffers
1586*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1587*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1588*89c4ff92SAndroid Build Coastguard Worker     }
1589*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
1590*89c4ff92SAndroid Build Coastguard Worker     unsigned int index = 0;
1591*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> alignedOutput(expectedMisalignedOutput.size());
1592*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(alignedOutput.data(), misalignedOutputPtr, expectedMisalignedOutput.size()*sizeof(float));
1593*89c4ff92SAndroid Build Coastguard Worker     for (auto outputValue : expectedMisalignedOutput)
1594*89c4ff92SAndroid Build Coastguard Worker     {
1595*89c4ff92SAndroid Build Coastguard Worker         CHECK(outputValue == alignedOutput[index]);
1596*89c4ff92SAndroid Build Coastguard Worker         ++index;
1597*89c4ff92SAndroid Build Coastguard Worker     }
1598*89c4ff92SAndroid Build Coastguard Worker     std::free(inputMemPtr);
1599*89c4ff92SAndroid Build Coastguard Worker     std::free(outputMemPtr);
1600*89c4ff92SAndroid Build Coastguard Worker 
1601*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
1602*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData
1603*89c4ff92SAndroid Build Coastguard Worker     {
1604*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
1605*89c4ff92SAndroid Build Coastguard Worker     };
1606*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
1607*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
1608*89c4ff92SAndroid Build Coastguard Worker     {
1609*89c4ff92SAndroid Build Coastguard Worker          1.0f, 4.0f, 9.0f, 16.0f
1610*89c4ff92SAndroid Build Coastguard Worker     };
1611*89c4ff92SAndroid Build Coastguard Worker 
1612*89c4ff92SAndroid Build Coastguard Worker     // Check our input and output pointers are actually aligned
1613*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(inputData.data()) % alignment));
1614*89c4ff92SAndroid Build Coastguard Worker     CHECK(!(reinterpret_cast<uintptr_t>(outputData.data()) % alignment));
1615*89c4ff92SAndroid Build Coastguard Worker 
1616*89c4ff92SAndroid Build Coastguard Worker     INFO("Create Inference");
1617*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
1618*89c4ff92SAndroid Build Coastguard Worker     {
1619*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
1620*89c4ff92SAndroid Build Coastguard Worker     };
1621*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
1622*89c4ff92SAndroid Build Coastguard Worker     {
1623*89c4ff92SAndroid Build Coastguard Worker         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
1624*89c4ff92SAndroid Build Coastguard Worker     };
1625*89c4ff92SAndroid Build Coastguard Worker 
1626*89c4ff92SAndroid Build Coastguard Worker     importedInputIds = runtime->ImportInputs(netId, inputTensors, MemorySource::Malloc);
1627*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedInputIds.size() == 1);
1628*89c4ff92SAndroid Build Coastguard Worker     importedOutputIds = runtime->ImportOutputs(netId, outputTensors, MemorySource::Malloc);
1629*89c4ff92SAndroid Build Coastguard Worker     CHECK(importedOutputIds.size() == 1);
1630*89c4ff92SAndroid Build Coastguard Worker     // Do the inference and force the import as the memory is aligned.
1631*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, InputTensors(), OutputTensors(), importedInputIds, importedOutputIds);
1632*89c4ff92SAndroid Build Coastguard Worker 
1633*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.AnalyzeEventsAndWriteResults() output to get the workload execution
1634*89c4ff92SAndroid Build Coastguard Worker     // We need to use AnalyzeEventsAndWriteResults here to make sure the second inference has been profiled
1635*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->AnalyzeEventsAndWriteResults(ss);
1636*89c4ff92SAndroid Build Coastguard Worker     dump = ss.str();
1637*89c4ff92SAndroid Build Coastguard Worker 
1638*89c4ff92SAndroid Build Coastguard Worker     if (backends[0] == Compute::CpuAcc)
1639*89c4ff92SAndroid Build Coastguard Worker     {
1640*89c4ff92SAndroid Build Coastguard Worker         // Reconfigure has not been implemented for CpuAcc so it will always copy, this will break whenever
1641*89c4ff92SAndroid Build Coastguard Worker         // reconfigure is implemented
1642*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1643*89c4ff92SAndroid Build Coastguard Worker         CHECK(count == 0);
1644*89c4ff92SAndroid Build Coastguard Worker         // Should be 2 CopyMemGeneric workloads
1645*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1646*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1647*89c4ff92SAndroid Build Coastguard Worker     }
1648*89c4ff92SAndroid Build Coastguard Worker     else
1649*89c4ff92SAndroid Build Coastguard Worker     {
1650*89c4ff92SAndroid Build Coastguard Worker         // Repeated inferences make it difficult to check for an accurate count. So we just validate that we have a
1651*89c4ff92SAndroid Build Coastguard Worker         // SyncMemGeneric Workload when we previously didn't
1652*89c4ff92SAndroid Build Coastguard Worker         int count = SubStringCounter(dump, "SyncMemGeneric");
1653*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1654*89c4ff92SAndroid Build Coastguard Worker         // Should still be some CopyMemGeneric Workloads from the last inference
1655*89c4ff92SAndroid Build Coastguard Worker         count = SubStringCounter(dump, "CopyMemGeneric");
1656*89c4ff92SAndroid Build Coastguard Worker         CHECK(count >= 1);
1657*89c4ff92SAndroid Build Coastguard Worker     }
1658*89c4ff92SAndroid Build Coastguard Worker     // Check the output is correct
1659*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutput.begin(), expectedOutput.end()));
1660*89c4ff92SAndroid Build Coastguard Worker     // Clean up to avoid interfering with other tests
1661*89c4ff92SAndroid Build Coastguard Worker     runtime->UnloadNetwork(netId);
1662*89c4ff92SAndroid Build Coastguard Worker }
1663*89c4ff92SAndroid Build Coastguard Worker 
1664*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
1665