1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IWorkingMemHandle.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Threadpool.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IAsyncExecutionCallback.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <AsyncExecutionCallback.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker #include <vector>
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker namespace armnn
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker namespace experimental
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker template<DataType ArmnnIType, DataType ArmnnOType,
29*89c4ff92SAndroid Build Coastguard Worker typename TInput = ResolveType <ArmnnIType>, typename TOutput = ResolveType <ArmnnOType>>
AsyncThreadedEndToEndTestImpl(INetworkPtr network,const std::vector<std::map<int,std::vector<TInput>>> & inputTensorData,const std::vector<std::map<int,std::vector<TOutput>>> & expectedOutputData,std::vector<BackendId> backends,const size_t numberOfInferences,float tolerance=0.000001f)30*89c4ff92SAndroid Build Coastguard Worker void AsyncThreadedEndToEndTestImpl(INetworkPtr network,
31*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::map<int, std::vector<TInput>>>& inputTensorData,
32*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::map<int, std::vector<TOutput>>>& expectedOutputData,
33*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends,
34*89c4ff92SAndroid Build Coastguard Worker const size_t numberOfInferences,
35*89c4ff92SAndroid Build Coastguard Worker float tolerance = 0.000001f)
36*89c4ff92SAndroid Build Coastguard Worker {
37*89c4ff92SAndroid Build Coastguard Worker // Create Runtime in which test will run
38*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options;
39*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options));
40*89c4ff92SAndroid Build Coastguard Worker
41*89c4ff92SAndroid Build Coastguard Worker // Optimize the Network
42*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec());
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker // Creates AsyncNetwork
45*89c4ff92SAndroid Build Coastguard Worker NetworkId networkId = 0;
46*89c4ff92SAndroid Build Coastguard Worker std::string errorMessage;
47*89c4ff92SAndroid Build Coastguard Worker const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
48*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker std::vector<InputTensors> inputTensorsVec;
51*89c4ff92SAndroid Build Coastguard Worker std::vector<OutputTensors> outputTensorsVec;
52*89c4ff92SAndroid Build Coastguard Worker std::vector<std::map<int, std::vector<TOutput>>> outputStorageVec;
53*89c4ff92SAndroid Build Coastguard Worker std::vector<std::unique_ptr<IWorkingMemHandle>> workingMemHandles;
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numberOfInferences; ++i)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors;
58*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors;
59*89c4ff92SAndroid Build Coastguard Worker outputStorageVec.emplace_back(std::map<int, std::vector<TOutput>>());
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker inputTensors.reserve(inputTensorData.size());
62*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : inputTensorData[i])
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo = runtime->GetInputTensorInfo(networkId, it.first);
65*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
66*89c4ff92SAndroid Build Coastguard Worker inputTensors.push_back({it.first,
67*89c4ff92SAndroid Build Coastguard Worker ConstTensor(inputTensorInfo, it.second.data())});
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker outputTensors.reserve(expectedOutputData.size());
71*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : expectedOutputData[i])
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker std::vector<TOutput> out(it.second.size());
74*89c4ff92SAndroid Build Coastguard Worker outputStorageVec[i].emplace(it.first, out);
75*89c4ff92SAndroid Build Coastguard Worker outputTensors.push_back({it.first,
76*89c4ff92SAndroid Build Coastguard Worker Tensor(runtime->GetOutputTensorInfo(networkId, it.first),
77*89c4ff92SAndroid Build Coastguard Worker outputStorageVec[i].at(it.first).data())});
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker inputTensorsVec.push_back(inputTensors);
81*89c4ff92SAndroid Build Coastguard Worker outputTensorsVec.push_back(outputTensors);
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker workingMemHandles.push_back(runtime->CreateWorkingMemHandle(networkId));
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker std::vector<std::thread> threads;
87*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numberOfInferences; ++i)
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker // Access the vectors before we do anything multi-threaded
90*89c4ff92SAndroid Build Coastguard Worker InputTensors& inputTensors = inputTensorsVec[i];
91*89c4ff92SAndroid Build Coastguard Worker OutputTensors& outputTensors = outputTensorsVec[i];
92*89c4ff92SAndroid Build Coastguard Worker IWorkingMemHandle& workingMemHandle = *workingMemHandles[i].get();
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker threads.emplace_back([&]()
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker // Run the async network
97*89c4ff92SAndroid Build Coastguard Worker runtime->Execute(workingMemHandle, inputTensors, outputTensors);
98*89c4ff92SAndroid Build Coastguard Worker });
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numberOfInferences; ++i)
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker threads[i].join();
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker // Checks the results.
107*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numberOfInferences; ++i)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker for (auto &&it : expectedOutputData[i])
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker std::vector<TOutput> out = outputStorageVec[i].at(it.first);
112*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < out.size(); ++j)
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker CHECK(Compare<ArmnnOType>(it.second[j], out[j], tolerance) == true);
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker template<DataType ArmnnIType, DataType ArmnnOType,
122*89c4ff92SAndroid Build Coastguard Worker typename TInput = ResolveType<ArmnnIType>, typename TOutput = ResolveType<ArmnnOType>>
AsyncEndToEndTestImpl(INetworkPtr network,const std::map<int,std::vector<TInput>> & inputTensorData,const std::map<int,std::vector<TOutput>> & expectedOutputData,std::vector<BackendId> backends,float tolerance=0.000001f,size_t numThreads=1)123*89c4ff92SAndroid Build Coastguard Worker void AsyncEndToEndTestImpl(INetworkPtr network,
124*89c4ff92SAndroid Build Coastguard Worker const std::map<int, std::vector<TInput>>& inputTensorData,
125*89c4ff92SAndroid Build Coastguard Worker const std::map<int, std::vector<TOutput>>& expectedOutputData,
126*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends,
127*89c4ff92SAndroid Build Coastguard Worker float tolerance = 0.000001f,
128*89c4ff92SAndroid Build Coastguard Worker size_t numThreads = 1)
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(numThreads >= 1);
131*89c4ff92SAndroid Build Coastguard Worker const unsigned int numberOfInferences = numThreads == 1 ? 1 : 1000;
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker // Create Runtime in which test will run
134*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options;
135*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options));
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker // Optimize the Network
138*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec());
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker // Creates AsyncNetwork
141*89c4ff92SAndroid Build Coastguard Worker NetworkId networkId = 0;
142*89c4ff92SAndroid Build Coastguard Worker
143*89c4ff92SAndroid Build Coastguard Worker std::string errorMessage;
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
146*89c4ff92SAndroid Build Coastguard Worker
147*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors;
150*89c4ff92SAndroid Build Coastguard Worker inputTensors.reserve(inputTensorData.size());
151*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : inputTensorData)
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo = runtime->GetInputTensorInfo(networkId, it.first);
154*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
155*89c4ff92SAndroid Build Coastguard Worker inputTensors.push_back({it.first,
156*89c4ff92SAndroid Build Coastguard Worker ConstTensor(inputTensorInfo, it.second.data())});
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker
159*89c4ff92SAndroid Build Coastguard Worker std::vector<OutputTensors> outputTensorsVec;
160*89c4ff92SAndroid Build Coastguard Worker std::vector<std::map<int, std::vector<TOutput>>> outputStorageVec;
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker outputTensorsVec.reserve(numberOfInferences);
163*89c4ff92SAndroid Build Coastguard Worker outputStorageVec.reserve(numberOfInferences);
164*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numberOfInferences; ++i)
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors;
167*89c4ff92SAndroid Build Coastguard Worker outputStorageVec.emplace_back(std::map<int, std::vector<TOutput>>());
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker outputTensors.reserve(expectedOutputData.size());
170*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : expectedOutputData)
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker std::vector<TOutput> out(it.second.size());
173*89c4ff92SAndroid Build Coastguard Worker outputStorageVec[i].emplace(it.first, out);
174*89c4ff92SAndroid Build Coastguard Worker outputTensors.push_back({it.first,
175*89c4ff92SAndroid Build Coastguard Worker Tensor(runtime->GetOutputTensorInfo(networkId, it.first),
176*89c4ff92SAndroid Build Coastguard Worker outputStorageVec[i].at(it.first).data())});
177*89c4ff92SAndroid Build Coastguard Worker }
178*89c4ff92SAndroid Build Coastguard Worker
179*89c4ff92SAndroid Build Coastguard Worker outputTensorsVec.push_back(outputTensors);
180*89c4ff92SAndroid Build Coastguard Worker }
181*89c4ff92SAndroid Build Coastguard Worker
182*89c4ff92SAndroid Build Coastguard Worker if (numThreads == 1)
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker // Create WorkingMemHandle for this async network
185*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkingMemHandle> workingMemHandle = runtime->CreateWorkingMemHandle(networkId);
186*89c4ff92SAndroid Build Coastguard Worker IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get();
187*89c4ff92SAndroid Build Coastguard Worker
188*89c4ff92SAndroid Build Coastguard Worker // Run the async network
189*89c4ff92SAndroid Build Coastguard Worker runtime->Execute(workingMemHandleRef, inputTensors, outputTensorsVec[0]);
190*89c4ff92SAndroid Build Coastguard Worker }
191*89c4ff92SAndroid Build Coastguard Worker else
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;
194*89c4ff92SAndroid Build Coastguard Worker
195*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numThreads; ++i)
196*89c4ff92SAndroid Build Coastguard Worker {
197*89c4ff92SAndroid Build Coastguard Worker memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
198*89c4ff92SAndroid Build Coastguard Worker }
199*89c4ff92SAndroid Build Coastguard Worker
200*89c4ff92SAndroid Build Coastguard Worker Threadpool threadpool(numThreads, runtime.get(), memHandles);
201*89c4ff92SAndroid Build Coastguard Worker AsyncCallbackManager callbackManager;
202*89c4ff92SAndroid Build Coastguard Worker
203*89c4ff92SAndroid Build Coastguard Worker // For the asyncronous execution, we are adding a pool of working memory handles (1 per thread) in the
204*89c4ff92SAndroid Build Coastguard Worker // LoadedNetwork with each scheduled inference having a random priority
205*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numberOfInferences; ++i)
206*89c4ff92SAndroid Build Coastguard Worker {
207*89c4ff92SAndroid Build Coastguard Worker threadpool.Schedule(networkId,
208*89c4ff92SAndroid Build Coastguard Worker inputTensors,
209*89c4ff92SAndroid Build Coastguard Worker outputTensorsVec[i],
210*89c4ff92SAndroid Build Coastguard Worker static_cast<QosExecPriority>(rand()%3),
211*89c4ff92SAndroid Build Coastguard Worker callbackManager.GetNewCallback());
212*89c4ff92SAndroid Build Coastguard Worker }
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker // Wait until the execution signals a notify
215*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numberOfInferences; ++i)
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker auto cb = callbackManager.GetNotifiedCallback();
218*89c4ff92SAndroid Build Coastguard Worker
219*89c4ff92SAndroid Build Coastguard Worker // Checks the results.
220*89c4ff92SAndroid Build Coastguard Worker CHECK(cb->GetStatus() == Status::Success);
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker
224*89c4ff92SAndroid Build Coastguard Worker for (auto&& outputStorage : outputStorageVec)
225*89c4ff92SAndroid Build Coastguard Worker {
226*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : expectedOutputData)
227*89c4ff92SAndroid Build Coastguard Worker {
228*89c4ff92SAndroid Build Coastguard Worker std::vector<TOutput> out = outputStorage.at(it.first);
229*89c4ff92SAndroid Build Coastguard Worker
230*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < out.size(); ++i)
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker //CHECK(Compare<ArmnnOType>(it.second[i], out[i], tolerance) == true);
233*89c4ff92SAndroid Build Coastguard Worker CHECK(it.second[i] == doctest::Approx(out[i]).epsilon(tolerance));
234*89c4ff92SAndroid Build Coastguard Worker }
235*89c4ff92SAndroid Build Coastguard Worker }
236*89c4ff92SAndroid Build Coastguard Worker }
237*89c4ff92SAndroid Build Coastguard Worker }
238*89c4ff92SAndroid Build Coastguard Worker
239*89c4ff92SAndroid Build Coastguard Worker template<typename armnn::DataType DataType>
CreateStridedSliceNetwork(const TensorShape & inputShape,const TensorShape & outputShape,const std::vector<int> & beginData,const std::vector<int> & endData,const std::vector<int> & stridesData,int beginMask=0,int endMask=0,int shrinkAxisMask=0,int ellipsisMask=0,int newAxisMask=0,const float qScale=1.0f,const int32_t qOffset=0)240*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateStridedSliceNetwork(const TensorShape& inputShape,
241*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape,
242*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& beginData,
243*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& endData,
244*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& stridesData,
245*89c4ff92SAndroid Build Coastguard Worker int beginMask = 0,
246*89c4ff92SAndroid Build Coastguard Worker int endMask = 0,
247*89c4ff92SAndroid Build Coastguard Worker int shrinkAxisMask = 0,
248*89c4ff92SAndroid Build Coastguard Worker int ellipsisMask = 0,
249*89c4ff92SAndroid Build Coastguard Worker int newAxisMask = 0,
250*89c4ff92SAndroid Build Coastguard Worker const float qScale = 1.0f,
251*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = 0)
252*89c4ff92SAndroid Build Coastguard Worker {
253*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
254*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network.
255*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create());
256*89c4ff92SAndroid Build Coastguard Worker
257*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, DataType, qScale, qOffset);
258*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
259*89c4ff92SAndroid Build Coastguard Worker
260*89c4ff92SAndroid Build Coastguard Worker armnn::StridedSliceDescriptor stridedSliceDescriptor;
261*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_Begin = beginData;
262*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_End = endData;
263*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_Stride = stridesData;
264*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_BeginMask = beginMask;
265*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_EndMask = endMask;
266*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_ShrinkAxisMask = shrinkAxisMask;
267*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_EllipsisMask = ellipsisMask;
268*89c4ff92SAndroid Build Coastguard Worker stridedSliceDescriptor.m_NewAxisMask = newAxisMask;
269*89c4ff92SAndroid Build Coastguard Worker
270*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input = net->AddInputLayer(0, "Input_Layer");
271*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* stridedSlice = net->AddStridedSliceLayer(stridedSliceDescriptor, "splitter");
272*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0);
273*89c4ff92SAndroid Build Coastguard Worker
274*89c4ff92SAndroid Build Coastguard Worker Connect(input, stridedSlice, inputTensorInfo, 0, 0);
275*89c4ff92SAndroid Build Coastguard Worker Connect(stridedSlice, output, outputTensorInfo, 0, 0);
276*89c4ff92SAndroid Build Coastguard Worker
277*89c4ff92SAndroid Build Coastguard Worker return net;
278*89c4ff92SAndroid Build Coastguard Worker }
279*89c4ff92SAndroid Build Coastguard Worker
280*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
StridedSlicedEndToEndTest(const std::vector<BackendId> & backends,size_t numThreads)281*89c4ff92SAndroid Build Coastguard Worker void StridedSlicedEndToEndTest(const std::vector<BackendId>& backends, size_t numThreads)
282*89c4ff92SAndroid Build Coastguard Worker {
283*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
284*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
285*89c4ff92SAndroid Build Coastguard Worker
286*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = {3, 2, 3, 1};
287*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = {1, 2, 3, 1};
288*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& beginData = {1, 0, 0, 0};
289*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& endData = {2, 2, 3, 1};
290*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& stridesData = {1, 1, 1, 1};
291*89c4ff92SAndroid Build Coastguard Worker int beginMask = 0;
292*89c4ff92SAndroid Build Coastguard Worker int endMask = 0;
293*89c4ff92SAndroid Build Coastguard Worker int shrinkAxisMask = 0;
294*89c4ff92SAndroid Build Coastguard Worker int ellipsisMask = 0;
295*89c4ff92SAndroid Build Coastguard Worker int newAxisMask = 0;
296*89c4ff92SAndroid Build Coastguard Worker
297*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
298*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateStridedSliceNetwork<ArmnnType>(inputShape,
299*89c4ff92SAndroid Build Coastguard Worker outputShape,
300*89c4ff92SAndroid Build Coastguard Worker beginData,
301*89c4ff92SAndroid Build Coastguard Worker endData,
302*89c4ff92SAndroid Build Coastguard Worker stridesData,
303*89c4ff92SAndroid Build Coastguard Worker beginMask,
304*89c4ff92SAndroid Build Coastguard Worker endMask,
305*89c4ff92SAndroid Build Coastguard Worker shrinkAxisMask,
306*89c4ff92SAndroid Build Coastguard Worker ellipsisMask,
307*89c4ff92SAndroid Build Coastguard Worker newAxisMask);
308*89c4ff92SAndroid Build Coastguard Worker
309*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
310*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
311*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
312*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
313*89c4ff92SAndroid Build Coastguard Worker
314*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
315*89c4ff92SAndroid Build Coastguard Worker
316*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
317*89c4ff92SAndroid Build Coastguard Worker };
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker std::vector<T> outputExpected{
320*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f
321*89c4ff92SAndroid Build Coastguard Worker };
322*89c4ff92SAndroid Build Coastguard Worker
323*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{0, inputData}};
324*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = {{0, outputExpected}};
325*89c4ff92SAndroid Build Coastguard Worker
326*89c4ff92SAndroid Build Coastguard Worker AsyncEndToEndTestImpl<ArmnnType, ArmnnType>(move(net),
327*89c4ff92SAndroid Build Coastguard Worker inputTensorData,
328*89c4ff92SAndroid Build Coastguard Worker expectedOutputData,
329*89c4ff92SAndroid Build Coastguard Worker backends,
330*89c4ff92SAndroid Build Coastguard Worker 0.000001f,
331*89c4ff92SAndroid Build Coastguard Worker numThreads);
332*89c4ff92SAndroid Build Coastguard Worker }
333*89c4ff92SAndroid Build Coastguard Worker
334*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
StridedSlicedMultiThreadedEndToEndTest(const std::vector<BackendId> & backends)335*89c4ff92SAndroid Build Coastguard Worker void StridedSlicedMultiThreadedEndToEndTest(const std::vector<BackendId>& backends)
336*89c4ff92SAndroid Build Coastguard Worker {
337*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
338*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
339*89c4ff92SAndroid Build Coastguard Worker
340*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = {3, 2, 3, 1};
341*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = {1, 2, 3, 1};
342*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& beginData = {1, 0, 0, 0};
343*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& endData = {2, 2, 3, 1};
344*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& stridesData = {1, 1, 1, 1};
345*89c4ff92SAndroid Build Coastguard Worker int beginMask = 0;
346*89c4ff92SAndroid Build Coastguard Worker int endMask = 0;
347*89c4ff92SAndroid Build Coastguard Worker int shrinkAxisMask = 0;
348*89c4ff92SAndroid Build Coastguard Worker int ellipsisMask = 0;
349*89c4ff92SAndroid Build Coastguard Worker int newAxisMask = 0;
350*89c4ff92SAndroid Build Coastguard Worker
351*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
352*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateStridedSliceNetwork<ArmnnType>(inputShape,
353*89c4ff92SAndroid Build Coastguard Worker outputShape,
354*89c4ff92SAndroid Build Coastguard Worker beginData,
355*89c4ff92SAndroid Build Coastguard Worker endData,
356*89c4ff92SAndroid Build Coastguard Worker stridesData,
357*89c4ff92SAndroid Build Coastguard Worker beginMask,
358*89c4ff92SAndroid Build Coastguard Worker endMask,
359*89c4ff92SAndroid Build Coastguard Worker shrinkAxisMask,
360*89c4ff92SAndroid Build Coastguard Worker ellipsisMask,
361*89c4ff92SAndroid Build Coastguard Worker newAxisMask);
362*89c4ff92SAndroid Build Coastguard Worker
363*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
364*89c4ff92SAndroid Build Coastguard Worker
365*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
366*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData1{
367*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
368*89c4ff92SAndroid Build Coastguard Worker
369*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
370*89c4ff92SAndroid Build Coastguard Worker
371*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
372*89c4ff92SAndroid Build Coastguard Worker };
373*89c4ff92SAndroid Build Coastguard Worker
374*89c4ff92SAndroid Build Coastguard Worker std::vector<T> outputExpected1{ 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };
375*89c4ff92SAndroid Build Coastguard Worker
376*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
377*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData2{
378*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
379*89c4ff92SAndroid Build Coastguard Worker
380*89c4ff92SAndroid Build Coastguard Worker 8.0f, 8.0f, 8.0f, 7.0f, 7.0f, 7.0f,
381*89c4ff92SAndroid Build Coastguard Worker
382*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
383*89c4ff92SAndroid Build Coastguard Worker };
384*89c4ff92SAndroid Build Coastguard Worker
385*89c4ff92SAndroid Build Coastguard Worker std::vector<T> outputExpected2{ 8.0f, 8.0f, 8.0f, 7.0f, 7.0f, 7.0f };
386*89c4ff92SAndroid Build Coastguard Worker
387*89c4ff92SAndroid Build Coastguard Worker std::vector<std::map<int, std::vector<T>>> inputTensors;
388*89c4ff92SAndroid Build Coastguard Worker std::vector<std::map<int, std::vector<T>>> outputTensors;
389*89c4ff92SAndroid Build Coastguard Worker
390*89c4ff92SAndroid Build Coastguard Worker inputTensors.push_back(std::map<int, std::vector<T>> {{0, inputData1}});
391*89c4ff92SAndroid Build Coastguard Worker inputTensors.push_back(std::map<int, std::vector<T>> {{0, inputData2}});
392*89c4ff92SAndroid Build Coastguard Worker outputTensors.push_back(std::map<int, std::vector<T>> {{0, outputExpected1}});
393*89c4ff92SAndroid Build Coastguard Worker outputTensors.push_back(std::map<int, std::vector<T>> {{0, outputExpected2}});
394*89c4ff92SAndroid Build Coastguard Worker
395*89c4ff92SAndroid Build Coastguard Worker AsyncThreadedEndToEndTestImpl<ArmnnType, ArmnnType>(move(net), inputTensors, outputTensors, backends, 2);
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker
398*89c4ff92SAndroid Build Coastguard Worker } // experimental namespace
399*89c4ff92SAndroid Build Coastguard Worker
400*89c4ff92SAndroid Build Coastguard Worker } // armnn namespace
401*89c4ff92SAndroid Build Coastguard Worker
402