1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2021-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include "TestUtils.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <Graph.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadFactory.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker #include <utility>
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Worker namespace
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker using namespace std;
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker // Calls CreateWorkload for a layer, and checks the returned pointer is of the correct type.
33*89c4ff92SAndroid Build Coastguard Worker template<typename Workload>
MakeAndCheckWorkload(Layer & layer,const IWorkloadFactory & factory,const ModelOptions & modelOptions={})34*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Workload> MakeAndCheckWorkload(Layer& layer,
35*89c4ff92SAndroid Build Coastguard Worker const IWorkloadFactory& factory,
36*89c4ff92SAndroid Build Coastguard Worker const ModelOptions& modelOptions = {})
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> workload = layer.CreateWorkload(factory);
39*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(workload.get() == PolymorphicDowncast<Workload*>(workload.get()),
40*89c4ff92SAndroid Build Coastguard Worker "Cannot convert to derived class");
41*89c4ff92SAndroid Build Coastguard Worker std::string reasonIfUnsupported;
42*89c4ff92SAndroid Build Coastguard Worker layer.SetBackendId(factory.GetBackendId());
43*89c4ff92SAndroid Build Coastguard Worker CHECK(factory.IsLayerSupported(layer, layer.GetDataType(), reasonIfUnsupported, modelOptions));
44*89c4ff92SAndroid Build Coastguard Worker return std::unique_ptr<Workload>(static_cast<Workload*>(workload.release()));
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker // Helper function to create tensor handlers for workloads, assuming they all use the same factory.
CreateTensorHandles(armnn::Graph & graph,armnn::IWorkloadFactory & factory)48*89c4ff92SAndroid Build Coastguard Worker void CreateTensorHandles(armnn::Graph& graph,
49*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory)
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry tmpRegistry;
52*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : graph.TopologicalSort())
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker layer->CreateTensorHandles(tmpRegistry, factory);
55*89c4ff92SAndroid Build Coastguard Worker }
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker
58*89c4ff92SAndroid Build Coastguard Worker /////////////////////////////////////////////////////////////////////////////////////////////
59*89c4ff92SAndroid Build Coastguard Worker // The following functions are called by backendsCommon/test/CreateWorkload*.cpp
60*89c4ff92SAndroid Build Coastguard Worker // They build very simple graphs, and then create a workload.
61*89c4ff92SAndroid Build Coastguard Worker // Some checks are performed on the workload to ensure parameters have been passed correctly.
62*89c4ff92SAndroid Build Coastguard Worker // They return the created workloads so that backend-specific checks can be performed.
63*89c4ff92SAndroid Build Coastguard Worker /////////////////////////////////////////////////////////////////////////////////////////////
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker template <typename ActivationWorkload, armnn::DataType DataType>
CreateActivationWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)66*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ActivationWorkload> CreateActivationWorkloadTest(armnn::IWorkloadFactory& factory,
67*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
70*89c4ff92SAndroid Build Coastguard Worker ActivationDescriptor layerDesc;
71*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_Function = ActivationFunction::ReLu;
72*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_A = 3.5f;
73*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_B = -10.0f;
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker ActivationLayer* const layer = graph.AddLayer<ActivationLayer>(layerDesc, "layer");
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
78*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
79*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker // Connects up.
82*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({1, 1}, DataType);
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo);
85*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
90*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<ActivationWorkload>(*layer, factory);
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker ActivationQueueDescriptor queueDescriptor = workload->GetData();
93*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
94*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
95*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_A == 3.5f);
96*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_B == -10.0f);
97*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_Function == ActivationFunction::ReLu));
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
100*89c4ff92SAndroid Build Coastguard Worker return workload;
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker template <typename WorkloadType,
104*89c4ff92SAndroid Build Coastguard Worker typename DescriptorType,
105*89c4ff92SAndroid Build Coastguard Worker typename LayerType,
106*89c4ff92SAndroid Build Coastguard Worker armnn::DataType DataType>
CreateElementwiseWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)107*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<WorkloadType> CreateElementwiseWorkloadTest(armnn::IWorkloadFactory & factory,
108*89c4ff92SAndroid Build Coastguard Worker armnn::Graph & graph)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
111*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<LayerType>("layer");
112*89c4ff92SAndroid Build Coastguard Worker
113*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
114*89c4ff92SAndroid Build Coastguard Worker Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
115*89c4ff92SAndroid Build Coastguard Worker Layer* const input2 = graph.AddLayer<InputLayer>(2, "input2");
116*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
117*89c4ff92SAndroid Build Coastguard Worker
118*89c4ff92SAndroid Build Coastguard Worker // Connects up.
119*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({2, 3}, DataType);
120*89c4ff92SAndroid Build Coastguard Worker Connect(input1, layer, tensorInfo, 0, 0);
121*89c4ff92SAndroid Build Coastguard Worker Connect(input2, layer, tensorInfo, 0, 1);
122*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
123*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
124*89c4ff92SAndroid Build Coastguard Worker
125*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
126*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<WorkloadType>(*layer, factory);
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker auto queueDescriptor = workload->GetData();
129*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
130*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
133*89c4ff92SAndroid Build Coastguard Worker return workload;
134*89c4ff92SAndroid Build Coastguard Worker }
135*89c4ff92SAndroid Build Coastguard Worker
136*89c4ff92SAndroid Build Coastguard Worker template <typename WorkloadType, armnn::DataType DataType>
CreateElementwiseBinaryWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,armnn::BinaryOperation binaryOperation)137*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<WorkloadType> CreateElementwiseBinaryWorkloadTest(armnn::IWorkloadFactory & factory,
138*89c4ff92SAndroid Build Coastguard Worker armnn::Graph & graph,
139*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation binaryOperation)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
142*89c4ff92SAndroid Build Coastguard Worker ElementwiseBinaryDescriptor descriptor(binaryOperation);
143*89c4ff92SAndroid Build Coastguard Worker //ElementwiseBinaryDescriptor descriptor = ElementwiseBinaryDescriptor(binaryOperation);
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<ElementwiseBinaryLayer>(descriptor, "layer");
146*89c4ff92SAndroid Build Coastguard Worker
147*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
148*89c4ff92SAndroid Build Coastguard Worker Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
149*89c4ff92SAndroid Build Coastguard Worker Layer* const input2 = graph.AddLayer<InputLayer>(2, "input2");
150*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker // Connects up.
153*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({2, 3}, DataType);
154*89c4ff92SAndroid Build Coastguard Worker Connect(input1, layer, tensorInfo, 0, 0);
155*89c4ff92SAndroid Build Coastguard Worker Connect(input2, layer, tensorInfo, 0, 1);
156*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
157*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
158*89c4ff92SAndroid Build Coastguard Worker
159*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
160*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<WorkloadType>(*layer, factory);
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker auto queueDescriptor = workload->GetData();
163*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
164*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
167*89c4ff92SAndroid Build Coastguard Worker return workload;
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker
170*89c4ff92SAndroid Build Coastguard Worker template<typename WorkloadType,
171*89c4ff92SAndroid Build Coastguard Worker typename DescriptorType,
172*89c4ff92SAndroid Build Coastguard Worker armnn::DataType DataType>
CreateSubtractionWithBlobWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)173*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<WorkloadType> CreateSubtractionWithBlobWorkloadTest(armnn::IWorkloadFactory& factory,
174*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
175*89c4ff92SAndroid Build Coastguard Worker {
176*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
177*89c4ff92SAndroid Build Coastguard Worker SubtractionLayer* const layer = graph.AddLayer<SubtractionLayer>("layer");
178*89c4ff92SAndroid Build Coastguard Worker
179*89c4ff92SAndroid Build Coastguard Worker auto activationDesc = std::make_shared<ActivationDescriptor>();
180*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_A = 10.0f;
181*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_B = 5.0f;
182*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_Function = armnn::ActivationFunction::BoundedReLu;
183*89c4ff92SAndroid Build Coastguard Worker
184*89c4ff92SAndroid Build Coastguard Worker layer->SetAdditionalInfoForObject(activationDesc);
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
187*89c4ff92SAndroid Build Coastguard Worker Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
188*89c4ff92SAndroid Build Coastguard Worker Layer* const input2 = graph.AddLayer<InputLayer>(2, "input2");
189*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker // Connects up.
192*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({2, 3}, DataType);
193*89c4ff92SAndroid Build Coastguard Worker Connect(input1, layer, tensorInfo, 0, 0);
194*89c4ff92SAndroid Build Coastguard Worker Connect(input2, layer, tensorInfo, 0, 1);
195*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
196*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker // Check that the additional information can be queried from the layer
199*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<ActivationDescriptor>
200*89c4ff92SAndroid Build Coastguard Worker activationDescPtr = layer->GetAdditionalInformation<ActivationDescriptor>();
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_A) == 10.0f);
203*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_B) == 5.0f);
204*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
205*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(activationDescPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
206*89c4ff92SAndroid Build Coastguard Worker );
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
209*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<WorkloadType>(*layer, factory);
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker DescriptorType queueDescriptor = workload->GetData();
212*89c4ff92SAndroid Build Coastguard Worker
213*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* queueDescBlobPtr =
214*89c4ff92SAndroid Build Coastguard Worker queueDescriptor.template GetAdditionalInformation<ActivationDescriptor>();
215*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(queueDescBlobPtr);
216*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_A) == 10.0f);
217*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_B) == 5.0f);
218*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
219*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(queueDescBlobPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
220*89c4ff92SAndroid Build Coastguard Worker );
221*89c4ff92SAndroid Build Coastguard Worker
222*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
223*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
224*89c4ff92SAndroid Build Coastguard Worker
225*89c4ff92SAndroid Build Coastguard Worker return workload;
226*89c4ff92SAndroid Build Coastguard Worker }
227*89c4ff92SAndroid Build Coastguard Worker
228*89c4ff92SAndroid Build Coastguard Worker
229*89c4ff92SAndroid Build Coastguard Worker template<typename WorkloadType,
230*89c4ff92SAndroid Build Coastguard Worker typename DescriptorType,
231*89c4ff92SAndroid Build Coastguard Worker armnn::DataType DataType>
CreateMultiplicationWithBlobWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)232*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<WorkloadType> CreateMultiplicationWithBlobWorkloadTest(armnn::IWorkloadFactory& factory,
233*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
234*89c4ff92SAndroid Build Coastguard Worker {
235*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
236*89c4ff92SAndroid Build Coastguard Worker MultiplicationLayer* const layer = graph.AddLayer<MultiplicationLayer>("layer");
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker auto activationDesc = std::make_shared<ActivationDescriptor>();
239*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_A = 10.0f;
240*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_B = 5.0f;
241*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_Function = armnn::ActivationFunction::BoundedReLu;
242*89c4ff92SAndroid Build Coastguard Worker
243*89c4ff92SAndroid Build Coastguard Worker layer->SetAdditionalInfoForObject(activationDesc);
244*89c4ff92SAndroid Build Coastguard Worker
245*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
246*89c4ff92SAndroid Build Coastguard Worker Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
247*89c4ff92SAndroid Build Coastguard Worker Layer* const input2 = graph.AddLayer<InputLayer>(2, "input2");
248*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
249*89c4ff92SAndroid Build Coastguard Worker
250*89c4ff92SAndroid Build Coastguard Worker // Connects up.
251*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({2, 3}, DataType);
252*89c4ff92SAndroid Build Coastguard Worker Connect(input1, layer, tensorInfo, 0, 0);
253*89c4ff92SAndroid Build Coastguard Worker Connect(input2, layer, tensorInfo, 0, 1);
254*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
255*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
256*89c4ff92SAndroid Build Coastguard Worker
257*89c4ff92SAndroid Build Coastguard Worker // Check that the additional information can be queried from the layer
258*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<ActivationDescriptor>
259*89c4ff92SAndroid Build Coastguard Worker activationDescPtr = layer->GetAdditionalInformation<ActivationDescriptor>();
260*89c4ff92SAndroid Build Coastguard Worker
261*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_A) == 10.0f);
262*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_B) == 5.0f);
263*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
264*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(activationDescPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
265*89c4ff92SAndroid Build Coastguard Worker );
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
268*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<WorkloadType>(*layer, factory);
269*89c4ff92SAndroid Build Coastguard Worker
270*89c4ff92SAndroid Build Coastguard Worker DescriptorType queueDescriptor = workload->GetData();
271*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
272*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
273*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* queueDescBlobPtr =
274*89c4ff92SAndroid Build Coastguard Worker queueDescriptor.template GetAdditionalInformation<ActivationDescriptor>();
275*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(queueDescBlobPtr);
276*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_A) == 10.0f);
277*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_B) == 5.0f);
278*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
279*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(queueDescBlobPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
280*89c4ff92SAndroid Build Coastguard Worker );
281*89c4ff92SAndroid Build Coastguard Worker
282*89c4ff92SAndroid Build Coastguard Worker return workload;// Returns so we can do extra, backend-specific tests.
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker template<typename WorkloadType,
286*89c4ff92SAndroid Build Coastguard Worker typename DescriptorType,
287*89c4ff92SAndroid Build Coastguard Worker armnn::DataType DataType>
CreateAdditionWithBlobWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)288*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<WorkloadType> CreateAdditionWithBlobWorkloadTest(armnn::IWorkloadFactory& factory,
289*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
290*89c4ff92SAndroid Build Coastguard Worker {
291*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
292*89c4ff92SAndroid Build Coastguard Worker AdditionLayer* const layer = graph.AddLayer<AdditionLayer>("layer");
293*89c4ff92SAndroid Build Coastguard Worker
294*89c4ff92SAndroid Build Coastguard Worker auto activationDesc = std::make_shared<ActivationDescriptor>();
295*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_A = 10.0f;
296*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_B = 5.0f;
297*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_Function = armnn::ActivationFunction::BoundedReLu;
298*89c4ff92SAndroid Build Coastguard Worker
299*89c4ff92SAndroid Build Coastguard Worker layer->SetAdditionalInfoForObject(activationDesc);
300*89c4ff92SAndroid Build Coastguard Worker
301*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
302*89c4ff92SAndroid Build Coastguard Worker Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
303*89c4ff92SAndroid Build Coastguard Worker Layer* const input2 = graph.AddLayer<InputLayer>(2, "input2");
304*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
305*89c4ff92SAndroid Build Coastguard Worker
306*89c4ff92SAndroid Build Coastguard Worker // Connects up.
307*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({2, 3}, DataType);
308*89c4ff92SAndroid Build Coastguard Worker Connect(input1, layer, tensorInfo, 0, 0);
309*89c4ff92SAndroid Build Coastguard Worker Connect(input2, layer, tensorInfo, 0, 1);
310*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
311*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
312*89c4ff92SAndroid Build Coastguard Worker
313*89c4ff92SAndroid Build Coastguard Worker // Check that the additional information can be queried from the layer
314*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<ActivationDescriptor>
315*89c4ff92SAndroid Build Coastguard Worker activationDescPtr = layer->template GetAdditionalInformation<ActivationDescriptor>();
316*89c4ff92SAndroid Build Coastguard Worker
317*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_A) == 10.0f);
318*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_B) == 5.0f);
319*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
320*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(activationDescPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
321*89c4ff92SAndroid Build Coastguard Worker );
322*89c4ff92SAndroid Build Coastguard Worker
323*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
324*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<WorkloadType>(*layer, factory);
325*89c4ff92SAndroid Build Coastguard Worker
326*89c4ff92SAndroid Build Coastguard Worker DescriptorType queueDescriptor = workload->GetData();
327*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* queueDescBlobPtr =
328*89c4ff92SAndroid Build Coastguard Worker queueDescriptor.template GetAdditionalInformation<ActivationDescriptor>();
329*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(queueDescBlobPtr);
330*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
331*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
332*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_A) == 10.0f);
333*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_B) == 5.0f);
334*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
335*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(queueDescBlobPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
336*89c4ff92SAndroid Build Coastguard Worker );
337*89c4ff92SAndroid Build Coastguard Worker
338*89c4ff92SAndroid Build Coastguard Worker return workload;
339*89c4ff92SAndroid Build Coastguard Worker }
340*89c4ff92SAndroid Build Coastguard Worker
341*89c4ff92SAndroid Build Coastguard Worker template <typename WorkloadType,
342*89c4ff92SAndroid Build Coastguard Worker typename DescriptorType,
343*89c4ff92SAndroid Build Coastguard Worker armnn::DataType DataType>
CreateElementwiseUnaryWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,armnn::UnaryOperation op)344*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<WorkloadType> CreateElementwiseUnaryWorkloadTest(armnn::IWorkloadFactory & factory,
345*89c4ff92SAndroid Build Coastguard Worker armnn::Graph & graph,
346*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation op)
347*89c4ff92SAndroid Build Coastguard Worker {
348*89c4ff92SAndroid Build Coastguard Worker ElementwiseUnaryDescriptor desc = ElementwiseUnaryDescriptor(op);
349*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<armnn::ElementwiseUnaryLayer>(desc, "layer");
350*89c4ff92SAndroid Build Coastguard Worker
351*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
352*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
353*89c4ff92SAndroid Build Coastguard Worker
354*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({ 2, 3 }, DataType);
355*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo, 0, 0);
356*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo, 0, 0);
357*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
358*89c4ff92SAndroid Build Coastguard Worker
359*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<WorkloadType>(*layer, factory);
360*89c4ff92SAndroid Build Coastguard Worker DescriptorType queueDescriptor = workload->GetData();
361*89c4ff92SAndroid Build Coastguard Worker
362*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
363*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
364*89c4ff92SAndroid Build Coastguard Worker
365*89c4ff92SAndroid Build Coastguard Worker return workload;
366*89c4ff92SAndroid Build Coastguard Worker }
367*89c4ff92SAndroid Build Coastguard Worker
368*89c4ff92SAndroid Build Coastguard Worker template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
CreateBatchNormalizationWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW)369*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<BatchNormalizationWorkloadType> CreateBatchNormalizationWorkloadTest(
370*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
371*89c4ff92SAndroid Build Coastguard Worker {
372*89c4ff92SAndroid Build Coastguard Worker TensorShape tensorShape;
373*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
374*89c4ff92SAndroid Build Coastguard Worker {
375*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
376*89c4ff92SAndroid Build Coastguard Worker tensorShape = { 2, 4, 4, 3 };
377*89c4ff92SAndroid Build Coastguard Worker break;
378*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
379*89c4ff92SAndroid Build Coastguard Worker default:
380*89c4ff92SAndroid Build Coastguard Worker tensorShape = { 2, 3, 4, 4 };
381*89c4ff92SAndroid Build Coastguard Worker }
382*89c4ff92SAndroid Build Coastguard Worker
383*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
384*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationDescriptor layerDesc;
385*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_Eps = 0.05f;
386*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
387*89c4ff92SAndroid Build Coastguard Worker
388*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationLayer* const layer = graph.AddLayer<BatchNormalizationLayer>(layerDesc, "layer");
389*89c4ff92SAndroid Build Coastguard Worker
390*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightInfo({3}, DataType);
391*89c4ff92SAndroid Build Coastguard Worker layer->m_Mean = std::make_unique<ScopedTensorHandle>(weightInfo);
392*89c4ff92SAndroid Build Coastguard Worker layer->m_Variance = std::make_unique<ScopedTensorHandle>(weightInfo);
393*89c4ff92SAndroid Build Coastguard Worker layer->m_Beta = std::make_unique<ScopedTensorHandle>(weightInfo);
394*89c4ff92SAndroid Build Coastguard Worker layer->m_Gamma = std::make_unique<ScopedTensorHandle>(weightInfo);
395*89c4ff92SAndroid Build Coastguard Worker layer->m_Mean->Allocate();
396*89c4ff92SAndroid Build Coastguard Worker layer->m_Variance->Allocate();
397*89c4ff92SAndroid Build Coastguard Worker layer->m_Beta->Allocate();
398*89c4ff92SAndroid Build Coastguard Worker layer->m_Gamma->Allocate();
399*89c4ff92SAndroid Build Coastguard Worker
400*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
401*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
402*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
403*89c4ff92SAndroid Build Coastguard Worker
404*89c4ff92SAndroid Build Coastguard Worker // Connects up.
405*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo(tensorShape, DataType);
406*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo);
407*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
408*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
409*89c4ff92SAndroid Build Coastguard Worker
410*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
411*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<BatchNormalizationWorkloadType>(*layer, factory);
412*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
413*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_Eps == 0.05f);
414*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
415*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
416*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Mean->GetTensorInfo() == TensorInfo({3}, DataType)));
417*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Variance->GetTensorInfo() == TensorInfo({3}, DataType)));
418*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Gamma->GetTensorInfo() == TensorInfo({3}, DataType)));
419*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Beta->GetTensorInfo() == TensorInfo({3}, DataType)));
420*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
421*89c4ff92SAndroid Build Coastguard Worker
422*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
423*89c4ff92SAndroid Build Coastguard Worker return workload;
424*89c4ff92SAndroid Build Coastguard Worker }
425*89c4ff92SAndroid Build Coastguard Worker
426*89c4ff92SAndroid Build Coastguard Worker template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
CreateBatchNormalizationWithBlobWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW)427*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<BatchNormalizationWorkloadType> CreateBatchNormalizationWithBlobWorkloadTest(
428*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
429*89c4ff92SAndroid Build Coastguard Worker {
430*89c4ff92SAndroid Build Coastguard Worker TensorShape tensorShape;
431*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
432*89c4ff92SAndroid Build Coastguard Worker {
433*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
434*89c4ff92SAndroid Build Coastguard Worker tensorShape = { 2, 4, 4, 3 };
435*89c4ff92SAndroid Build Coastguard Worker break;
436*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
437*89c4ff92SAndroid Build Coastguard Worker default:
438*89c4ff92SAndroid Build Coastguard Worker tensorShape = { 2, 3, 4, 4 };
439*89c4ff92SAndroid Build Coastguard Worker }
440*89c4ff92SAndroid Build Coastguard Worker
441*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
442*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationDescriptor layerDesc;
443*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_Eps = 0.05f;
444*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
445*89c4ff92SAndroid Build Coastguard Worker
446*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationLayer* const layer = graph.AddLayer<BatchNormalizationLayer>(layerDesc, "layer");
447*89c4ff92SAndroid Build Coastguard Worker
448*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightInfo({3}, DataType);
449*89c4ff92SAndroid Build Coastguard Worker layer->m_Mean = std::make_unique<ScopedTensorHandle>(weightInfo);
450*89c4ff92SAndroid Build Coastguard Worker layer->m_Variance = std::make_unique<ScopedTensorHandle>(weightInfo);
451*89c4ff92SAndroid Build Coastguard Worker layer->m_Beta = std::make_unique<ScopedTensorHandle>(weightInfo);
452*89c4ff92SAndroid Build Coastguard Worker layer->m_Gamma = std::make_unique<ScopedTensorHandle>(weightInfo);
453*89c4ff92SAndroid Build Coastguard Worker layer->m_Mean->Allocate();
454*89c4ff92SAndroid Build Coastguard Worker layer->m_Variance->Allocate();
455*89c4ff92SAndroid Build Coastguard Worker layer->m_Beta->Allocate();
456*89c4ff92SAndroid Build Coastguard Worker layer->m_Gamma->Allocate();
457*89c4ff92SAndroid Build Coastguard Worker
458*89c4ff92SAndroid Build Coastguard Worker auto activationDesc = std::make_shared<ActivationDescriptor>();
459*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_A = 10.0f;
460*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_B = 5.0f;
461*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_Function = armnn::ActivationFunction::BoundedReLu;
462*89c4ff92SAndroid Build Coastguard Worker
463*89c4ff92SAndroid Build Coastguard Worker layer->SetAdditionalInfoForObject(activationDesc);
464*89c4ff92SAndroid Build Coastguard Worker
465*89c4ff92SAndroid Build Coastguard Worker // Check that the additional information can be queried from the layer
466*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<ActivationDescriptor> activationDescPtr = layer->GetAdditionalInformation<ActivationDescriptor>();
467*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_A) == 10.0f);
468*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_B) == 5.0f);
469*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
470*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(activationDescPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
471*89c4ff92SAndroid Build Coastguard Worker );
472*89c4ff92SAndroid Build Coastguard Worker
473*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
474*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
475*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
476*89c4ff92SAndroid Build Coastguard Worker
477*89c4ff92SAndroid Build Coastguard Worker // Connects up.
478*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo(tensorShape, DataType);
479*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo);
480*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
481*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
482*89c4ff92SAndroid Build Coastguard Worker
483*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
484*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<BatchNormalizationWorkloadType>(*layer, factory);
485*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
486*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* queueDescBlobPtr = queueDescriptor.GetAdditionalInformation<ActivationDescriptor>();
487*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(queueDescBlobPtr);
488*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_A) == 10.0f);
489*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_B) == 5.0f);
490*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
491*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(queueDescBlobPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
492*89c4ff92SAndroid Build Coastguard Worker );
493*89c4ff92SAndroid Build Coastguard Worker
494*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_Eps == 0.05f);
495*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
496*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
497*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Mean->GetTensorInfo() == TensorInfo({3}, DataType)));
498*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Variance->GetTensorInfo() == TensorInfo({3}, DataType)));
499*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Gamma->GetTensorInfo() == TensorInfo({3}, DataType)));
500*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Beta->GetTensorInfo() == TensorInfo({3}, DataType)));
501*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
502*89c4ff92SAndroid Build Coastguard Worker
503*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
504*89c4ff92SAndroid Build Coastguard Worker return workload;
505*89c4ff92SAndroid Build Coastguard Worker }
506*89c4ff92SAndroid Build Coastguard Worker
507*89c4ff92SAndroid Build Coastguard Worker template <typename Convolution2dWorkload, armnn::DataType DataType>
CreateConvolution2dWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW,const ModelOptions & modelOptions={})508*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Convolution2dWorkload> CreateConvolution2dWorkloadTest(armnn::IWorkloadFactory& factory,
509*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
510*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW,
511*89c4ff92SAndroid Build Coastguard Worker const ModelOptions& modelOptions = {})
512*89c4ff92SAndroid Build Coastguard Worker {
513*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
514*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor layerDesc;
515*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadLeft = 3;
516*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadRight = 3;
517*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadTop = 1;
518*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadBottom = 1;
519*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideX = 2;
520*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideY = 4;
521*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = false;
522*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
523*89c4ff92SAndroid Build Coastguard Worker
524*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
525*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
526*89c4ff92SAndroid Build Coastguard Worker
527*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const layer = graph.AddLayer<Convolution2dLayer>(layerDesc, "layer");
528*89c4ff92SAndroid Build Coastguard Worker
529*89c4ff92SAndroid Build Coastguard Worker TensorShape weightShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 5, 3} : TensorShape{2, 5, 3, 3};
530*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 8, 16} : TensorShape{2, 8, 16, 3};
531*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 2, 2, 10} : TensorShape{2, 2, 10, 2};
532*89c4ff92SAndroid Build Coastguard Worker
533*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsTensorInfo(weightShape, DataType, inputsQScale);
534*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetConstant();
535*89c4ff92SAndroid Build Coastguard Worker
536*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
537*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
538*89c4ff92SAndroid Build Coastguard Worker auto const weights = graph.AddLayer<ConstantLayer>("weights");
539*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
540*89c4ff92SAndroid Build Coastguard Worker
541*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput = std::make_unique<ScopedTensorHandle>(weightsTensorInfo);
542*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput->Allocate();
543*89c4ff92SAndroid Build Coastguard Worker
544*89c4ff92SAndroid Build Coastguard Worker // Connects up.
545*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo(inputShape, DataType, inputsQScale));
546*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, weightsTensorInfo, 0, 1);
547*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo(outputShape, DataType, outputQScale));
548*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
549*89c4ff92SAndroid Build Coastguard Worker
550*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
551*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<Convolution2dWorkload>(*layer, factory, modelOptions);
552*89c4ff92SAndroid Build Coastguard Worker
553*89c4ff92SAndroid Build Coastguard Worker Convolution2dQueueDescriptor queueDescriptor = workload->GetData();
554*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideX == 2);
555*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideY == 4);
556*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadLeft == 3);
557*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadRight == 3);
558*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadTop == 1);
559*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadBottom == 1);
560*89c4ff92SAndroid Build Coastguard Worker CHECK(!queueDescriptor.m_Parameters.m_BiasEnabled);
561*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
562*89c4ff92SAndroid Build Coastguard Worker
563*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
564*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
565*89c4ff92SAndroid Build Coastguard Worker
566*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
567*89c4ff92SAndroid Build Coastguard Worker return workload;
568*89c4ff92SAndroid Build Coastguard Worker }
569*89c4ff92SAndroid Build Coastguard Worker
570*89c4ff92SAndroid Build Coastguard Worker template<typename Convolution2dWorkload, armnn::DataType DataType>
CreateConvolution2dFusedActivationWithBlobWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW,const ModelOptions & modelOptions={})571*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Convolution2dWorkload> CreateConvolution2dFusedActivationWithBlobWorkloadTest(
572*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory,
573*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
574*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW,
575*89c4ff92SAndroid Build Coastguard Worker const ModelOptions& modelOptions = {})
576*89c4ff92SAndroid Build Coastguard Worker {
577*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
578*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor layerDesc;
579*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadLeft = 3;
580*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadRight = 3;
581*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadTop = 1;
582*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadBottom = 1;
583*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideX = 2;
584*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideY = 4;
585*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = true;
586*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
587*89c4ff92SAndroid Build Coastguard Worker
588*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
589*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
590*89c4ff92SAndroid Build Coastguard Worker
591*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const layer = graph.AddLayer<Convolution2dLayer>(layerDesc, "layer");
592*89c4ff92SAndroid Build Coastguard Worker
593*89c4ff92SAndroid Build Coastguard Worker TensorShape weightShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 5, 3} : TensorShape{2, 5, 3, 3};
594*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 8, 16} : TensorShape{2, 8, 16, 3};
595*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 2, 2, 10} : TensorShape{2, 2, 10, 2};
596*89c4ff92SAndroid Build Coastguard Worker
597*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsTensorInfo(weightShape, DataType, inputsQScale);
598*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetConstant();
599*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasTensorInfo({2}, DataType, inputsQScale);
600*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo.SetConstant();
601*89c4ff92SAndroid Build Coastguard Worker
602*89c4ff92SAndroid Build Coastguard Worker auto activationDesc = std::make_shared<ActivationDescriptor>();
603*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_A = 10.0f;
604*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_B = 5.0f;
605*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_Function = armnn::ActivationFunction::BoundedReLu;
606*89c4ff92SAndroid Build Coastguard Worker
607*89c4ff92SAndroid Build Coastguard Worker layer->SetAdditionalInfoForObject(activationDesc);
608*89c4ff92SAndroid Build Coastguard Worker
609*89c4ff92SAndroid Build Coastguard Worker // Check that the additional information can be queried from the layer
610*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<ActivationDescriptor> activationDescPtr = layer->GetAdditionalInformation<ActivationDescriptor>();
611*89c4ff92SAndroid Build Coastguard Worker
612*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_A) == 10.0f);
613*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_B) == 5.0f);
614*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
615*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(activationDescPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
616*89c4ff92SAndroid Build Coastguard Worker );
617*89c4ff92SAndroid Build Coastguard Worker
618*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
619*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
620*89c4ff92SAndroid Build Coastguard Worker auto const weights = graph.AddLayer<ConstantLayer>("weights");
621*89c4ff92SAndroid Build Coastguard Worker auto const bias = graph.AddLayer<ConstantLayer>("bias");
622*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
623*89c4ff92SAndroid Build Coastguard Worker
624*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput = std::make_unique<ScopedTensorHandle>(weightsTensorInfo);
625*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput->Allocate();
626*89c4ff92SAndroid Build Coastguard Worker bias->m_LayerOutput = std::make_unique<ScopedTensorHandle>(biasTensorInfo);
627*89c4ff92SAndroid Build Coastguard Worker bias->m_LayerOutput->Allocate();
628*89c4ff92SAndroid Build Coastguard Worker
629*89c4ff92SAndroid Build Coastguard Worker // Connects up.
630*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo(inputShape, DataType, inputsQScale));
631*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, weightsTensorInfo, 0, 1);
632*89c4ff92SAndroid Build Coastguard Worker Connect(bias, layer, biasTensorInfo, 0, 2);
633*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo(outputShape, DataType, outputQScale));
634*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
635*89c4ff92SAndroid Build Coastguard Worker
636*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
637*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<Convolution2dWorkload>(*layer, factory, modelOptions);
638*89c4ff92SAndroid Build Coastguard Worker
639*89c4ff92SAndroid Build Coastguard Worker Convolution2dQueueDescriptor queueDescriptor = workload->GetData();
640*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* queueDescBlobPtr = queueDescriptor.GetAdditionalInformation<ActivationDescriptor>();
641*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(queueDescBlobPtr);
642*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_A) == 10.0f);
643*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_B) == 5.0f);
644*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
645*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(queueDescBlobPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
646*89c4ff92SAndroid Build Coastguard Worker );
647*89c4ff92SAndroid Build Coastguard Worker
648*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideX == 2);
649*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideY == 4);
650*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadLeft == 3);
651*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadRight == 3);
652*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadTop == 1);
653*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadBottom == 1);
654*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_BiasEnabled);
655*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
656*89c4ff92SAndroid Build Coastguard Worker
657*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
658*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
659*89c4ff92SAndroid Build Coastguard Worker
660*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
661*89c4ff92SAndroid Build Coastguard Worker return workload;
662*89c4ff92SAndroid Build Coastguard Worker }
663*89c4ff92SAndroid Build Coastguard Worker
664*89c4ff92SAndroid Build Coastguard Worker template <typename Convolution2dWorkload, armnn::DataType DataType>
CreateConvolution2dWorkloadFastMathTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW,const ModelOptions & modelOptions={})665*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Convolution2dWorkload> CreateConvolution2dWorkloadFastMathTest(armnn::IWorkloadFactory& factory,
666*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
667*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW,
668*89c4ff92SAndroid Build Coastguard Worker const ModelOptions& modelOptions = {})
669*89c4ff92SAndroid Build Coastguard Worker {
670*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
671*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor layerDesc;
672*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadLeft = 0;
673*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadRight = 0;
674*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadTop = 0;
675*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadBottom = 0;
676*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideX = 1;
677*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideY = 1;
678*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = true;
679*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
680*89c4ff92SAndroid Build Coastguard Worker
681*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
682*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
683*89c4ff92SAndroid Build Coastguard Worker
684*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const layer = graph.AddLayer<Convolution2dLayer>(layerDesc, "layer");
685*89c4ff92SAndroid Build Coastguard Worker
686*89c4ff92SAndroid Build Coastguard Worker TensorShape weightShape = TensorShape{ 32, 32, 3, 3 };
687*89c4ff92SAndroid Build Coastguard Worker TensorShape biasShape = TensorShape{ 32 };
688*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = TensorShape{ 1, 32, 149, 149 };
689*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = TensorShape{ 1, 32, 147, 147 };
690*89c4ff92SAndroid Build Coastguard Worker
691*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsTensorInfo(weightShape, DataType, inputsQScale);
692*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetConstant();
693*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasTensorInfo(biasShape, DataType, inputsQScale);
694*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo.SetConstant();
695*89c4ff92SAndroid Build Coastguard Worker
696*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
697*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
698*89c4ff92SAndroid Build Coastguard Worker auto const weights = graph.AddLayer<ConstantLayer>("weights");
699*89c4ff92SAndroid Build Coastguard Worker auto const bias = graph.AddLayer<ConstantLayer>("bias");
700*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
701*89c4ff92SAndroid Build Coastguard Worker
702*89c4ff92SAndroid Build Coastguard Worker // Connects up.
703*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo(inputShape, DataType));
704*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, weightsTensorInfo, 0, 1);
705*89c4ff92SAndroid Build Coastguard Worker Connect(bias, layer, biasTensorInfo, 0, 2);
706*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo(outputShape, DataType, outputQScale));
707*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
708*89c4ff92SAndroid Build Coastguard Worker
709*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
710*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<Convolution2dWorkload>(*layer, factory, modelOptions);
711*89c4ff92SAndroid Build Coastguard Worker
712*89c4ff92SAndroid Build Coastguard Worker Convolution2dQueueDescriptor queueDescriptor = workload->GetData();
713*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideX == 1);
714*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideY == 1);
715*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadLeft == 0);
716*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadRight == 0);
717*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadTop == 0);
718*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadBottom == 0);
719*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
720*89c4ff92SAndroid Build Coastguard Worker
721*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
722*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
723*89c4ff92SAndroid Build Coastguard Worker
724*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
725*89c4ff92SAndroid Build Coastguard Worker return workload;
726*89c4ff92SAndroid Build Coastguard Worker }
727*89c4ff92SAndroid Build Coastguard Worker
728*89c4ff92SAndroid Build Coastguard Worker template <typename LstmWorkload>
CreateLstmWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)729*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<LstmWorkload> CreateLstmWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph)
730*89c4ff92SAndroid Build Coastguard Worker {
731*89c4ff92SAndroid Build Coastguard Worker // This parameter setting is for withCifgWithPeepholeNoProjection
732*89c4ff92SAndroid Build Coastguard Worker LstmDescriptor layerDesc;
733*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ActivationFunc = 4;
734*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ClippingThresCell = 0.0f;
735*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ClippingThresProj = 0.0f;
736*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_CifgEnabled = true;
737*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PeepholeEnabled = true;
738*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ProjectionEnabled = false;
739*89c4ff92SAndroid Build Coastguard Worker
740*89c4ff92SAndroid Build Coastguard Worker LstmLayer* const layer = graph.AddLayer<LstmLayer>(layerDesc, "layer");
741*89c4ff92SAndroid Build Coastguard Worker unsigned int batchSize = 2;
742*89c4ff92SAndroid Build Coastguard Worker unsigned int inputSize = 2;
743*89c4ff92SAndroid Build Coastguard Worker unsigned int numUnits = 4;
744*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSize = 4;
745*89c4ff92SAndroid Build Coastguard Worker
746*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToForgetWeights = std::make_unique<ScopedTensorHandle>
747*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, inputSize }, DataType::Float32));
748*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToCellWeights = std::make_unique<ScopedTensorHandle>
749*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, inputSize }, DataType::Float32));
750*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToOutputWeights = std::make_unique<ScopedTensorHandle>
751*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, inputSize }, DataType::Float32));
752*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToForgetWeights = std::make_unique<ScopedTensorHandle>
753*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, outputSize }, DataType::Float32));
754*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToCellWeights = std::make_unique<ScopedTensorHandle>
755*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, outputSize }, DataType::Float32));
756*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToOutputWeights = std::make_unique<ScopedTensorHandle>
757*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, outputSize }, DataType::Float32));
758*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_ForgetGateBias = std::make_unique<ScopedTensorHandle>
759*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
760*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_CellBias = std::make_unique<ScopedTensorHandle>
761*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
762*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_OutputGateBias = std::make_unique<ScopedTensorHandle>
763*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
764*89c4ff92SAndroid Build Coastguard Worker
765*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToForgetWeights->Allocate();
766*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToCellWeights->Allocate();
767*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToOutputWeights->Allocate();
768*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToForgetWeights->Allocate();
769*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToCellWeights->Allocate();
770*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToOutputWeights->Allocate();
771*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_ForgetGateBias->Allocate();
772*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_CellBias->Allocate();
773*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_OutputGateBias->Allocate();
774*89c4ff92SAndroid Build Coastguard Worker
775*89c4ff92SAndroid Build Coastguard Worker
776*89c4ff92SAndroid Build Coastguard Worker if (layerDesc.m_PeepholeEnabled)
777*89c4ff92SAndroid Build Coastguard Worker {
778*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToForgetWeights = std::make_unique<ScopedTensorHandle>
779*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
780*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToOutputWeights = std::make_unique<ScopedTensorHandle>
781*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
782*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToForgetWeights->Allocate();
783*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToOutputWeights->Allocate();
784*89c4ff92SAndroid Build Coastguard Worker }
785*89c4ff92SAndroid Build Coastguard Worker
786*89c4ff92SAndroid Build Coastguard Worker // create input and output layers
787*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
788*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateIn = graph.AddLayer<InputLayer>(1, "outputStateIn");
789*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateIn = graph.AddLayer<InputLayer>(2, "cellStateIn");
790*89c4ff92SAndroid Build Coastguard Worker Layer* const scratchBuffer = graph.AddLayer<OutputLayer>(0, "scratchBuffer");
791*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateOut = graph.AddLayer<OutputLayer>(1, "outputStateOut");
792*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateOut = graph.AddLayer<OutputLayer>(2, "cellStateOut");
793*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(3, "output");
794*89c4ff92SAndroid Build Coastguard Worker
795*89c4ff92SAndroid Build Coastguard Worker // connect up
796*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfo1({ batchSize, inputSize }, DataType::Float32);
797*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfo2({ batchSize, numUnits}, DataType::Float32);
798*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfo3({ batchSize, outputSize }, DataType::Float32);
799*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * (layerDesc.m_CifgEnabled ? 3 : 4) },
800*89c4ff92SAndroid Build Coastguard Worker DataType::Float32);
801*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, lstmTensorInfo1, 0, 0);
802*89c4ff92SAndroid Build Coastguard Worker Connect(cellStateIn, layer, lstmTensorInfo2, 0, 1);
803*89c4ff92SAndroid Build Coastguard Worker Connect(outputStateIn, layer, lstmTensorInfo3, 0, 2);
804*89c4ff92SAndroid Build Coastguard Worker Connect(layer, scratchBuffer, lstmTensorInfoScratchBuff, 0, 0);
805*89c4ff92SAndroid Build Coastguard Worker Connect(layer, outputStateOut, lstmTensorInfo3, 1, 0);
806*89c4ff92SAndroid Build Coastguard Worker Connect(layer, cellStateOut, lstmTensorInfo2, 2, 0);
807*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, lstmTensorInfo3, 3, 0);
808*89c4ff92SAndroid Build Coastguard Worker
809*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
810*89c4ff92SAndroid Build Coastguard Worker
811*89c4ff92SAndroid Build Coastguard Worker // make the workload and check it
812*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<LstmWorkload>(*layer, factory);
813*89c4ff92SAndroid Build Coastguard Worker LstmQueueDescriptor queueDescriptor = workload->GetData();
814*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_ActivationFunc == 4);
815*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_ClippingThresCell == 0.0f);
816*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_ClippingThresProj == 0.0f);
817*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
818*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 4);
819*89c4ff92SAndroid Build Coastguard Worker
820*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToForgetWeights->GetTensorInfo() == TensorInfo({ numUnits, inputSize },
821*89c4ff92SAndroid Build Coastguard Worker DataType::Float32)));
822*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_OutputGateBias->GetTensorInfo() == TensorInfo({ numUnits },
823*89c4ff92SAndroid Build Coastguard Worker DataType::Float32)));
824*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_CellBias->GetTensorInfo() == TensorInfo({ numUnits }, DataType::Float32)));
825*89c4ff92SAndroid Build Coastguard Worker return workload;
826*89c4ff92SAndroid Build Coastguard Worker }
827*89c4ff92SAndroid Build Coastguard Worker
828*89c4ff92SAndroid Build Coastguard Worker template <typename QuantizedLstmWorkload>
CreateQuantizedLstmWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)829*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<QuantizedLstmWorkload> CreateQuantizedLstmWorkloadTest(armnn::IWorkloadFactory& factory,
830*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
831*89c4ff92SAndroid Build Coastguard Worker {
832*89c4ff92SAndroid Build Coastguard Worker auto layer = graph.AddLayer<QuantizedLstmLayer>("quantizedLstmlayer");
833*89c4ff92SAndroid Build Coastguard Worker unsigned int numBatches = 2;
834*89c4ff92SAndroid Build Coastguard Worker unsigned int inputSize = 2;
835*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSize = 4;
836*89c4ff92SAndroid Build Coastguard Worker
837*89c4ff92SAndroid Build Coastguard Worker // Scale/Offset for input/output, cellState In/Out, weights, bias
838*89c4ff92SAndroid Build Coastguard Worker float inputOutputScale = 0.0078125f;
839*89c4ff92SAndroid Build Coastguard Worker int32_t inputOutputOffset = 128;
840*89c4ff92SAndroid Build Coastguard Worker
841*89c4ff92SAndroid Build Coastguard Worker float cellStateScale = 0.00048828125f;
842*89c4ff92SAndroid Build Coastguard Worker int32_t cellStateOffset = 0;
843*89c4ff92SAndroid Build Coastguard Worker
844*89c4ff92SAndroid Build Coastguard Worker float weightsScale = 0.00408021f;
845*89c4ff92SAndroid Build Coastguard Worker int32_t weightsOffset = 100;
846*89c4ff92SAndroid Build Coastguard Worker
847*89c4ff92SAndroid Build Coastguard Worker float biasScale = 3.1876640625e-05f;
848*89c4ff92SAndroid Build Coastguard Worker int32_t biasOffset = 0;
849*89c4ff92SAndroid Build Coastguard Worker
850*89c4ff92SAndroid Build Coastguard Worker // Weights and bias tensor and quantization info
851*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
852*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QAsymmU8,
853*89c4ff92SAndroid Build Coastguard Worker weightsScale,
854*89c4ff92SAndroid Build Coastguard Worker weightsOffset);
855*89c4ff92SAndroid Build Coastguard Worker
856*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
857*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QAsymmU8,
858*89c4ff92SAndroid Build Coastguard Worker weightsScale,
859*89c4ff92SAndroid Build Coastguard Worker weightsOffset);
860*89c4ff92SAndroid Build Coastguard Worker
861*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasInfo({outputSize},
862*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::Signed32,
863*89c4ff92SAndroid Build Coastguard Worker biasScale,
864*89c4ff92SAndroid Build Coastguard Worker biasOffset);
865*89c4ff92SAndroid Build Coastguard Worker
866*89c4ff92SAndroid Build Coastguard Worker // Weights and bias
867*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToInputWeights =
868*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(inputWeightsInfo);
869*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToForgetWeights =
870*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(inputWeightsInfo);
871*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToCellWeights =
872*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(inputWeightsInfo);
873*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToOutputWeights =
874*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(inputWeightsInfo);
875*89c4ff92SAndroid Build Coastguard Worker
876*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights =
877*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(recurrentWeightsInfo);
878*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights =
879*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(recurrentWeightsInfo);
880*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights =
881*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(recurrentWeightsInfo);
882*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights =
883*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(recurrentWeightsInfo);
884*89c4ff92SAndroid Build Coastguard Worker
885*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputGateBias = std::make_unique<ScopedTensorHandle>(biasInfo);
886*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_ForgetGateBias = std::make_unique<ScopedTensorHandle>(biasInfo);
887*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_CellBias = std::make_unique<ScopedTensorHandle>(biasInfo);
888*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_OutputGateBias = std::make_unique<ScopedTensorHandle>(biasInfo);
889*89c4ff92SAndroid Build Coastguard Worker
890*89c4ff92SAndroid Build Coastguard Worker // Allocate weights and bias
891*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToInputWeights->Allocate();
892*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToForgetWeights->Allocate();
893*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToCellWeights->Allocate();
894*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputToOutputWeights->Allocate();
895*89c4ff92SAndroid Build Coastguard Worker
896*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->Allocate();
897*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Allocate();
898*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->Allocate();
899*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Allocate();
900*89c4ff92SAndroid Build Coastguard Worker
901*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_InputGateBias->Allocate();
902*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_ForgetGateBias->Allocate();
903*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_CellBias->Allocate();
904*89c4ff92SAndroid Build Coastguard Worker layer->m_QuantizedLstmParameters.m_OutputGateBias->Allocate();
905*89c4ff92SAndroid Build Coastguard Worker
906*89c4ff92SAndroid Build Coastguard Worker // Create input and output layers
907*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
908*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateIn = graph.AddLayer<InputLayer>(1, "cellStateIn");
909*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateIn = graph.AddLayer<InputLayer>(2, "outputStateIn");
910*89c4ff92SAndroid Build Coastguard Worker
911*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateOut = graph.AddLayer<OutputLayer>(0, "cellStateOut");
912*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateOut = graph.AddLayer<OutputLayer>(1, "outputStateOut");
913*89c4ff92SAndroid Build Coastguard Worker
914*89c4ff92SAndroid Build Coastguard Worker // Input/output tensor info and quantization info
915*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo({numBatches , inputSize},
916*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QAsymmU8,
917*89c4ff92SAndroid Build Coastguard Worker inputOutputScale,
918*89c4ff92SAndroid Build Coastguard Worker inputOutputOffset);
919*89c4ff92SAndroid Build Coastguard Worker
920*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo cellStateInfo({numBatches , outputSize},
921*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QSymmS16,
922*89c4ff92SAndroid Build Coastguard Worker cellStateScale,
923*89c4ff92SAndroid Build Coastguard Worker cellStateOffset);
924*89c4ff92SAndroid Build Coastguard Worker
925*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputStateInfo({numBatches , outputSize},
926*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QAsymmU8,
927*89c4ff92SAndroid Build Coastguard Worker inputOutputScale,
928*89c4ff92SAndroid Build Coastguard Worker inputOutputOffset);
929*89c4ff92SAndroid Build Coastguard Worker
930*89c4ff92SAndroid Build Coastguard Worker // Connect input/output slots
931*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputInfo, 0, 0);
932*89c4ff92SAndroid Build Coastguard Worker Connect(cellStateIn, layer, cellStateInfo, 0, 1);
933*89c4ff92SAndroid Build Coastguard Worker Connect(outputStateIn, layer, outputStateInfo, 0, 2);
934*89c4ff92SAndroid Build Coastguard Worker
935*89c4ff92SAndroid Build Coastguard Worker Connect(layer, cellStateOut, cellStateInfo, 0, 0);
936*89c4ff92SAndroid Build Coastguard Worker Connect(layer, outputStateOut, outputStateInfo, 1, 0);
937*89c4ff92SAndroid Build Coastguard Worker
938*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
939*89c4ff92SAndroid Build Coastguard Worker
940*89c4ff92SAndroid Build Coastguard Worker // Create workload and check layer support
941*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<QuantizedLstmWorkload>(*layer, factory);
942*89c4ff92SAndroid Build Coastguard Worker QuantizedLstmQueueDescriptor queueDescriptor = workload->GetData();
943*89c4ff92SAndroid Build Coastguard Worker
944*89c4ff92SAndroid Build Coastguard Worker // Validate input/output sizes
945*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
946*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 2);
947*89c4ff92SAndroid Build Coastguard Worker
948*89c4ff92SAndroid Build Coastguard Worker // Validate weight tensor info
949*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToInputWeights->GetTensorInfo() == inputWeightsInfo));
950*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToForgetWeights->GetTensorInfo() == inputWeightsInfo));
951*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToCellWeights->GetTensorInfo() == inputWeightsInfo));
952*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToOutputWeights->GetTensorInfo() == inputWeightsInfo));
953*89c4ff92SAndroid Build Coastguard Worker
954*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_RecurrentToInputWeights->GetTensorInfo() == recurrentWeightsInfo));
955*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_RecurrentToForgetWeights->GetTensorInfo() == recurrentWeightsInfo));
956*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_RecurrentToCellWeights->GetTensorInfo() == recurrentWeightsInfo));
957*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_RecurrentToOutputWeights->GetTensorInfo() == recurrentWeightsInfo));
958*89c4ff92SAndroid Build Coastguard Worker
959*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputGateBias->GetTensorInfo() == biasInfo));
960*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_ForgetGateBias->GetTensorInfo() == biasInfo));
961*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_CellBias->GetTensorInfo() == biasInfo));
962*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_OutputGateBias->GetTensorInfo() == biasInfo));
963*89c4ff92SAndroid Build Coastguard Worker
964*89c4ff92SAndroid Build Coastguard Worker return workload;
965*89c4ff92SAndroid Build Coastguard Worker }
966*89c4ff92SAndroid Build Coastguard Worker
967*89c4ff92SAndroid Build Coastguard Worker template <typename QLstmWorkload>
CreateQLstmWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)968*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<QLstmWorkload> CreateQLstmWorkloadTest(armnn::IWorkloadFactory& factory,
969*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
970*89c4ff92SAndroid Build Coastguard Worker {
971*89c4ff92SAndroid Build Coastguard Worker QLstmDescriptor layerDesc;
972*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_CifgEnabled = true;
973*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PeepholeEnabled = false;
974*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ProjectionEnabled = false;
975*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_LayerNormEnabled = true;
976*89c4ff92SAndroid Build Coastguard Worker
977*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_CellClip = 0.0f;
978*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ProjectionClip = 0.0f;
979*89c4ff92SAndroid Build Coastguard Worker
980*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_HiddenStateZeroPoint = 0;
981*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_HiddenStateScale = 0.007f;
982*89c4ff92SAndroid Build Coastguard Worker
983*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_InputIntermediateScale = 0.007059f;
984*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ForgetIntermediateScale = 0.007812f;
985*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_CellIntermediateScale = 0.007059f;
986*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_OutputIntermediateScale = 0.007812f;
987*89c4ff92SAndroid Build Coastguard Worker
988*89c4ff92SAndroid Build Coastguard Worker QLstmLayer* const layer = graph.AddLayer<QLstmLayer>(layerDesc, "qLstm");
989*89c4ff92SAndroid Build Coastguard Worker
990*89c4ff92SAndroid Build Coastguard Worker unsigned int numBatches = 2;
991*89c4ff92SAndroid Build Coastguard Worker unsigned int inputSize = 4;
992*89c4ff92SAndroid Build Coastguard Worker unsigned int numUnits = 4;
993*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSize = 4;
994*89c4ff92SAndroid Build Coastguard Worker
995*89c4ff92SAndroid Build Coastguard Worker // Scale/Offset quantization info
996*89c4ff92SAndroid Build Coastguard Worker float inputScale = 0.0078125f;
997*89c4ff92SAndroid Build Coastguard Worker int32_t inputOffset = 0;
998*89c4ff92SAndroid Build Coastguard Worker
999*89c4ff92SAndroid Build Coastguard Worker // if (!projectionEnabled) outputScale == hiddenStateScale
1000*89c4ff92SAndroid Build Coastguard Worker float outputScale = layerDesc.m_HiddenStateScale;
1001*89c4ff92SAndroid Build Coastguard Worker int32_t outputOffset = layerDesc.m_HiddenStateZeroPoint;
1002*89c4ff92SAndroid Build Coastguard Worker
1003*89c4ff92SAndroid Build Coastguard Worker float cellStateScale = 3.05176e-05f;
1004*89c4ff92SAndroid Build Coastguard Worker int32_t cellStateOffset = 0;
1005*89c4ff92SAndroid Build Coastguard Worker
1006*89c4ff92SAndroid Build Coastguard Worker float weightsScale = 0.00784314f;
1007*89c4ff92SAndroid Build Coastguard Worker int32_t weightsOffset = 0;
1008*89c4ff92SAndroid Build Coastguard Worker
1009*89c4ff92SAndroid Build Coastguard Worker float layerNormScale = 3.05182e-05f;
1010*89c4ff92SAndroid Build Coastguard Worker int32_t layerNormOffset = 0;
1011*89c4ff92SAndroid Build Coastguard Worker
1012*89c4ff92SAndroid Build Coastguard Worker float biasScale = layerNormScale / 1024;
1013*89c4ff92SAndroid Build Coastguard Worker int32_t biasOffset = 0;
1014*89c4ff92SAndroid Build Coastguard Worker
1015*89c4ff92SAndroid Build Coastguard Worker // Weights and bias tensor and quantization info
1016*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
1017*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QSymmS8,
1018*89c4ff92SAndroid Build Coastguard Worker weightsScale,
1019*89c4ff92SAndroid Build Coastguard Worker weightsOffset);
1020*89c4ff92SAndroid Build Coastguard Worker
1021*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
1022*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QSymmS8,
1023*89c4ff92SAndroid Build Coastguard Worker weightsScale,
1024*89c4ff92SAndroid Build Coastguard Worker weightsOffset);
1025*89c4ff92SAndroid Build Coastguard Worker
1026*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1027*89c4ff92SAndroid Build Coastguard Worker
1028*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
1029*89c4ff92SAndroid Build Coastguard Worker
1030*89c4ff92SAndroid Build Coastguard Worker // Create and allocate tensors
1031*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToForgetWeights = std::make_unique<ScopedTensorHandle>(inputWeightsInfo);
1032*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToCellWeights = std::make_unique<ScopedTensorHandle>(inputWeightsInfo);
1033*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToOutputWeights = std::make_unique<ScopedTensorHandle>(inputWeightsInfo);
1034*89c4ff92SAndroid Build Coastguard Worker
1035*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToForgetWeights =
1036*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(recurrentWeightsInfo);
1037*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToCellWeights =
1038*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(recurrentWeightsInfo);
1039*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToOutputWeights =
1040*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(recurrentWeightsInfo);
1041*89c4ff92SAndroid Build Coastguard Worker
1042*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_ForgetGateBias = std::make_unique<ScopedTensorHandle>(biasInfo);
1043*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_CellBias = std::make_unique<ScopedTensorHandle>(biasInfo);
1044*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_OutputGateBias = std::make_unique<ScopedTensorHandle>(biasInfo);
1045*89c4ff92SAndroid Build Coastguard Worker
1046*89c4ff92SAndroid Build Coastguard Worker layer->m_LayerNormParameters.m_ForgetLayerNormWeights =
1047*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(layerNormWeightsInfo);
1048*89c4ff92SAndroid Build Coastguard Worker layer->m_LayerNormParameters.m_CellLayerNormWeights =
1049*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(layerNormWeightsInfo);
1050*89c4ff92SAndroid Build Coastguard Worker layer->m_LayerNormParameters.m_OutputLayerNormWeights =
1051*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ScopedTensorHandle>(layerNormWeightsInfo);
1052*89c4ff92SAndroid Build Coastguard Worker
1053*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToForgetWeights->Allocate();
1054*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToCellWeights->Allocate();
1055*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToOutputWeights->Allocate();
1056*89c4ff92SAndroid Build Coastguard Worker
1057*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToForgetWeights->Allocate();
1058*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToCellWeights->Allocate();
1059*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToOutputWeights->Allocate();
1060*89c4ff92SAndroid Build Coastguard Worker
1061*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_ForgetGateBias->Allocate();
1062*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_CellBias->Allocate();
1063*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_OutputGateBias->Allocate();
1064*89c4ff92SAndroid Build Coastguard Worker
1065*89c4ff92SAndroid Build Coastguard Worker layer->m_LayerNormParameters.m_ForgetLayerNormWeights->Allocate();
1066*89c4ff92SAndroid Build Coastguard Worker layer->m_LayerNormParameters.m_CellLayerNormWeights->Allocate();
1067*89c4ff92SAndroid Build Coastguard Worker layer->m_LayerNormParameters.m_OutputLayerNormWeights->Allocate();
1068*89c4ff92SAndroid Build Coastguard Worker
1069*89c4ff92SAndroid Build Coastguard Worker // Input and output layers
1070*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1071*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateIn = graph.AddLayer<InputLayer>(1, "outputStateIn");
1072*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateIn = graph.AddLayer<InputLayer>(2, "cellStateIn");
1073*89c4ff92SAndroid Build Coastguard Worker
1074*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateOut = graph.AddLayer<OutputLayer>(0, "outputStateOut");
1075*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateOut = graph.AddLayer<OutputLayer>(1, "cellStateOut");
1076*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(2, "output");
1077*89c4ff92SAndroid Build Coastguard Worker
1078*89c4ff92SAndroid Build Coastguard Worker // Input/Output tensor info
1079*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo({numBatches , inputSize},
1080*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QAsymmS8,
1081*89c4ff92SAndroid Build Coastguard Worker inputScale,
1082*89c4ff92SAndroid Build Coastguard Worker inputOffset);
1083*89c4ff92SAndroid Build Coastguard Worker
1084*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1085*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QSymmS16,
1086*89c4ff92SAndroid Build Coastguard Worker cellStateScale,
1087*89c4ff92SAndroid Build Coastguard Worker cellStateOffset);
1088*89c4ff92SAndroid Build Coastguard Worker
1089*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1090*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::QAsymmS8,
1091*89c4ff92SAndroid Build Coastguard Worker outputScale,
1092*89c4ff92SAndroid Build Coastguard Worker outputOffset);
1093*89c4ff92SAndroid Build Coastguard Worker
1094*89c4ff92SAndroid Build Coastguard Worker // Connect layers to slots
1095*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputInfo, 0, 0);
1096*89c4ff92SAndroid Build Coastguard Worker Connect(outputStateIn, layer, outputStateInfo, 0, 1);
1097*89c4ff92SAndroid Build Coastguard Worker Connect(cellStateIn, layer, cellStateInfo, 0, 2);
1098*89c4ff92SAndroid Build Coastguard Worker
1099*89c4ff92SAndroid Build Coastguard Worker Connect(layer, outputStateOut, outputStateInfo, 0, 0);
1100*89c4ff92SAndroid Build Coastguard Worker Connect(layer, cellStateOut, cellStateInfo, 1, 0);
1101*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputStateInfo, 2, 0);
1102*89c4ff92SAndroid Build Coastguard Worker
1103*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1104*89c4ff92SAndroid Build Coastguard Worker
1105*89c4ff92SAndroid Build Coastguard Worker // Create and check workload
1106*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<QLstmWorkload>(*layer, factory);
1107*89c4ff92SAndroid Build Coastguard Worker QLstmQueueDescriptor queueDescriptor = workload->GetData();
1108*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_CellClip == 0.0f);
1109*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_ProjectionClip == 0.0f);
1110*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
1111*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 3);
1112*89c4ff92SAndroid Build Coastguard Worker
1113*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToForgetWeights->GetTensorInfo() == inputWeightsInfo));
1114*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToCellWeights->GetTensorInfo() == inputWeightsInfo));
1115*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_InputToOutputWeights->GetTensorInfo() == inputWeightsInfo));
1116*89c4ff92SAndroid Build Coastguard Worker
1117*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_RecurrentToForgetWeights->GetTensorInfo() == recurrentWeightsInfo));
1118*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_RecurrentToCellWeights->GetTensorInfo() == recurrentWeightsInfo));
1119*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_RecurrentToOutputWeights->GetTensorInfo() == recurrentWeightsInfo));
1120*89c4ff92SAndroid Build Coastguard Worker
1121*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_ForgetGateBias->GetTensorInfo() == biasInfo));
1122*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_CellBias->GetTensorInfo() == biasInfo));
1123*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_OutputGateBias->GetTensorInfo() == biasInfo));
1124*89c4ff92SAndroid Build Coastguard Worker
1125*89c4ff92SAndroid Build Coastguard Worker return workload;
1126*89c4ff92SAndroid Build Coastguard Worker }
1127*89c4ff92SAndroid Build Coastguard Worker
1128*89c4ff92SAndroid Build Coastguard Worker template<typename Convolution2dWorkload, armnn::DataType DataType>
CreateDirectConvolution2dWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1129*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Convolution2dWorkload> CreateDirectConvolution2dWorkloadTest(armnn::IWorkloadFactory& factory,
1130*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1131*89c4ff92SAndroid Build Coastguard Worker {
1132*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1133*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor layerDesc;
1134*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadLeft = 1;
1135*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadRight = 1;
1136*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadTop = 1;
1137*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadBottom = 1;
1138*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideX = 1;
1139*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideY = 1;
1140*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = true;
1141*89c4ff92SAndroid Build Coastguard Worker
1142*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const layer = graph.AddLayer<Convolution2dLayer>(layerDesc, "layer");
1143*89c4ff92SAndroid Build Coastguard Worker
1144*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
1145*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
1146*89c4ff92SAndroid Build Coastguard Worker
1147*89c4ff92SAndroid Build Coastguard Worker TensorShape biasShape = TensorShape{ 2 };
1148*89c4ff92SAndroid Build Coastguard Worker TensorShape weightShape = TensorShape{ 2, 3, 3, 3 };
1149*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsTensorInfo(weightShape, DataType, inputsQScale);
1150*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetConstant();
1151*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasTensorInfo(biasShape, GetBiasDataType(DataType), inputsQScale);
1152*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo.SetConstant();
1153*89c4ff92SAndroid Build Coastguard Worker
1154*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1155*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1156*89c4ff92SAndroid Build Coastguard Worker auto const weights = graph.AddLayer<ConstantLayer>("weights");
1157*89c4ff92SAndroid Build Coastguard Worker auto const bias = graph.AddLayer<ConstantLayer>("bias");
1158*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1159*89c4ff92SAndroid Build Coastguard Worker
1160*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput = std::make_unique<ScopedTensorHandle>(weightsTensorInfo);
1161*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput->Allocate();
1162*89c4ff92SAndroid Build Coastguard Worker bias->m_LayerOutput = std::make_unique<ScopedTensorHandle>(biasTensorInfo);
1163*89c4ff92SAndroid Build Coastguard Worker bias->m_LayerOutput->Allocate();
1164*89c4ff92SAndroid Build Coastguard Worker
1165*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1166*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo({2, 3, 6, 6}, DataType, inputsQScale));
1167*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, weightsTensorInfo, 0, 1);
1168*89c4ff92SAndroid Build Coastguard Worker Connect(bias, layer, biasTensorInfo, 0, 2);
1169*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo({2, 2, 6, 6}, DataType, outputQScale));
1170*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1171*89c4ff92SAndroid Build Coastguard Worker
1172*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1173*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<Convolution2dWorkload>(*layer, factory);
1174*89c4ff92SAndroid Build Coastguard Worker
1175*89c4ff92SAndroid Build Coastguard Worker Convolution2dQueueDescriptor queueDescriptor = workload->GetData();
1176*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideX == 1);
1177*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideY == 1);
1178*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadLeft == 1);
1179*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadRight == 1);
1180*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadTop == 1);
1181*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadBottom == 1);
1182*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_BiasEnabled == true);
1183*89c4ff92SAndroid Build Coastguard Worker
1184*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
1185*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1186*89c4ff92SAndroid Build Coastguard Worker
1187*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1188*89c4ff92SAndroid Build Coastguard Worker return workload;
1189*89c4ff92SAndroid Build Coastguard Worker }
1190*89c4ff92SAndroid Build Coastguard Worker
1191*89c4ff92SAndroid Build Coastguard Worker template <typename DepthwiseConvolution2dFloat32Workload, armnn::DataType DataType>
CreateDepthwiseConvolution2dWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW)1192*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<DepthwiseConvolution2dFloat32Workload> CreateDepthwiseConvolution2dWorkloadTest(
1193*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
1194*89c4ff92SAndroid Build Coastguard Worker {
1195*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1196*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dDescriptor layerDesc;
1197*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadLeft = 1;
1198*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadRight = 2;
1199*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadTop = 1;
1200*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadBottom = 2;
1201*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideX = 1;
1202*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideY = 1;
1203*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = false;
1204*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
1205*89c4ff92SAndroid Build Coastguard Worker
1206*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
1207*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
1208*89c4ff92SAndroid Build Coastguard Worker
1209*89c4ff92SAndroid Build Coastguard Worker TensorShape weightShape({1, 4, 4, 2});
1210*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = (dataLayout == DataLayout::NCHW) ?
1211*89c4ff92SAndroid Build Coastguard Worker TensorShape{ 2, 2, 5, 5 } : TensorShape{ 2, 5, 5, 2 };
1212*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = (dataLayout == DataLayout::NCHW) ?
1213*89c4ff92SAndroid Build Coastguard Worker TensorShape{ 2, 2, 5, 5 } : TensorShape{ 2, 5, 5, 2 };
1214*89c4ff92SAndroid Build Coastguard Worker
1215*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dLayer* const layer = graph.AddLayer<DepthwiseConvolution2dLayer>(layerDesc, "layer");
1216*89c4ff92SAndroid Build Coastguard Worker
1217*89c4ff92SAndroid Build Coastguard Worker
1218*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1219*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1220*89c4ff92SAndroid Build Coastguard Worker Layer* const weights = graph.AddLayer<ConstantLayer>("weights");
1221*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1222*89c4ff92SAndroid Build Coastguard Worker
1223*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1224*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo(inputShape, DataType, inputsQScale));
1225*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, TensorInfo(weightShape, DataType, inputsQScale, 0.0f, true), 0, 1);
1226*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo(outputShape, DataType, outputQScale));
1227*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1228*89c4ff92SAndroid Build Coastguard Worker
1229*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1230*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<DepthwiseConvolution2dFloat32Workload>(*layer, factory);
1231*89c4ff92SAndroid Build Coastguard Worker
1232*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dQueueDescriptor queueDescriptor = workload->GetData();
1233*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideX == 1);
1234*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideY == 1);
1235*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadLeft == 1);
1236*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadRight == 2);
1237*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadTop == 1);
1238*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadBottom == 2);
1239*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_BiasEnabled == false);
1240*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
1241*89c4ff92SAndroid Build Coastguard Worker
1242*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
1243*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1244*89c4ff92SAndroid Build Coastguard Worker
1245*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1246*89c4ff92SAndroid Build Coastguard Worker return workload;
1247*89c4ff92SAndroid Build Coastguard Worker }
1248*89c4ff92SAndroid Build Coastguard Worker
1249*89c4ff92SAndroid Build Coastguard Worker template <typename FullyConnectedWorkload, armnn::DataType DataType>
CreateFullyConnectedWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1250*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<FullyConnectedWorkload> CreateFullyConnectedWorkloadTest(armnn::IWorkloadFactory& factory,
1251*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1252*89c4ff92SAndroid Build Coastguard Worker {
1253*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1254*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor layerDesc;
1255*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = false;
1256*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_TransposeWeightMatrix = true;
1257*89c4ff92SAndroid Build Coastguard Worker
1258*89c4ff92SAndroid Build Coastguard Worker FullyConnectedLayer* const layer = graph.AddLayer<FullyConnectedLayer>(layerDesc, "layer");
1259*89c4ff92SAndroid Build Coastguard Worker
1260*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
1261*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
1262*89c4ff92SAndroid Build Coastguard Worker
1263*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsTensorInfo({7, 20}, DataType, inputsQScale);
1264*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetConstant();
1265*89c4ff92SAndroid Build Coastguard Worker
1266*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1267*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1268*89c4ff92SAndroid Build Coastguard Worker auto const weights = graph.AddLayer<ConstantLayer>("weights");
1269*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1270*89c4ff92SAndroid Build Coastguard Worker
1271*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput = std::make_unique<ScopedTensorHandle>(weightsTensorInfo);
1272*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput->Allocate();
1273*89c4ff92SAndroid Build Coastguard Worker
1274*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1275*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo({3, 1, 4, 5}, DataType, inputsQScale), 0, 0);
1276*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, weightsTensorInfo, 0, 1);
1277*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo({3, 7}, DataType, outputQScale));
1278*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1279*89c4ff92SAndroid Build Coastguard Worker
1280*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1281*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<FullyConnectedWorkload>(*layer, factory);
1282*89c4ff92SAndroid Build Coastguard Worker
1283*89c4ff92SAndroid Build Coastguard Worker FullyConnectedQueueDescriptor queueDescriptor = workload->GetData();
1284*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_TransposeWeightMatrix == true);
1285*89c4ff92SAndroid Build Coastguard Worker
1286*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
1287*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1288*89c4ff92SAndroid Build Coastguard Worker
1289*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1290*89c4ff92SAndroid Build Coastguard Worker return workload;
1291*89c4ff92SAndroid Build Coastguard Worker }
1292*89c4ff92SAndroid Build Coastguard Worker
1293*89c4ff92SAndroid Build Coastguard Worker template <typename FullyConnectedWorkload, armnn::DataType DataType>
CreateFullyConnectedWithBlobWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1294*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<FullyConnectedWorkload> CreateFullyConnectedWithBlobWorkloadTest
1295*89c4ff92SAndroid Build Coastguard Worker (armnn::IWorkloadFactory& factory,
1296*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1297*89c4ff92SAndroid Build Coastguard Worker {
1298*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1299*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor layerDesc;
1300*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = true;
1301*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_TransposeWeightMatrix = true;
1302*89c4ff92SAndroid Build Coastguard Worker
1303*89c4ff92SAndroid Build Coastguard Worker FullyConnectedLayer* const layer = graph.AddLayer<FullyConnectedLayer>(layerDesc, "layer");
1304*89c4ff92SAndroid Build Coastguard Worker
1305*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
1306*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
1307*89c4ff92SAndroid Build Coastguard Worker
1308*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsTensorInfo({7, 20}, DataType, inputsQScale);
1309*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasesTensorInfo({7}, GetBiasDataType(DataType), inputsQScale);
1310*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetConstant();
1311*89c4ff92SAndroid Build Coastguard Worker biasesTensorInfo.SetConstant();
1312*89c4ff92SAndroid Build Coastguard Worker
1313*89c4ff92SAndroid Build Coastguard Worker auto activationDesc = std::make_shared<ActivationDescriptor>();
1314*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_A = 10.0f;
1315*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_B = 5.0f;
1316*89c4ff92SAndroid Build Coastguard Worker activationDesc->m_Function = armnn::ActivationFunction::BoundedReLu;
1317*89c4ff92SAndroid Build Coastguard Worker
1318*89c4ff92SAndroid Build Coastguard Worker layer->SetAdditionalInfoForObject(activationDesc);
1319*89c4ff92SAndroid Build Coastguard Worker
1320*89c4ff92SAndroid Build Coastguard Worker // Check that the additional information can be queried from the layer
1321*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<ActivationDescriptor> activationDescPtr = layer->GetAdditionalInformation<ActivationDescriptor>();
1322*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_A) == 10.0f);
1323*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(activationDescPtr->m_B) == 5.0f);
1324*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<ActivationFunction>(activationDescPtr->m_Function) ==
1325*89c4ff92SAndroid Build Coastguard Worker armnn::ActivationFunction::BoundedReLu);
1326*89c4ff92SAndroid Build Coastguard Worker
1327*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1328*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1329*89c4ff92SAndroid Build Coastguard Worker auto const weights = graph.AddLayer<ConstantLayer>("weights");
1330*89c4ff92SAndroid Build Coastguard Worker auto const biases = graph.AddLayer<ConstantLayer>("biases");
1331*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1332*89c4ff92SAndroid Build Coastguard Worker
1333*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput = std::make_unique<ScopedTensorHandle>(weightsTensorInfo);
1334*89c4ff92SAndroid Build Coastguard Worker weights->m_LayerOutput->Allocate();
1335*89c4ff92SAndroid Build Coastguard Worker biases->m_LayerOutput = std::make_unique<ScopedTensorHandle>(biasesTensorInfo);
1336*89c4ff92SAndroid Build Coastguard Worker biases->m_LayerOutput->Allocate();
1337*89c4ff92SAndroid Build Coastguard Worker
1338*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1339*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo({3, 1, 4, 5}, DataType, inputsQScale), 0, 0);
1340*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, weightsTensorInfo, 0, 1);
1341*89c4ff92SAndroid Build Coastguard Worker Connect(biases, layer, biasesTensorInfo, 0, 2);
1342*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo({3, 7}, DataType, outputQScale));
1343*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1344*89c4ff92SAndroid Build Coastguard Worker
1345*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1346*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<FullyConnectedWorkload>(*layer, factory);
1347*89c4ff92SAndroid Build Coastguard Worker
1348*89c4ff92SAndroid Build Coastguard Worker FullyConnectedQueueDescriptor queueDescriptor = workload->GetData();
1349*89c4ff92SAndroid Build Coastguard Worker
1350*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* queueDescBlobPtr = queueDescriptor.GetAdditionalInformation<ActivationDescriptor>();
1351*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(queueDescBlobPtr);
1352*89c4ff92SAndroid Build Coastguard Worker
1353*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_A) == 10.0f);
1354*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(static_cast<float>(queueDescBlobPtr->m_B) == 5.0f);
1355*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(
1356*89c4ff92SAndroid Build Coastguard Worker static_cast<ActivationFunction>(queueDescBlobPtr->m_Function) == armnn::ActivationFunction::BoundedReLu
1357*89c4ff92SAndroid Build Coastguard Worker );
1358*89c4ff92SAndroid Build Coastguard Worker
1359*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_BiasEnabled == true);
1360*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_TransposeWeightMatrix == true);
1361*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
1362*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1363*89c4ff92SAndroid Build Coastguard Worker
1364*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1365*89c4ff92SAndroid Build Coastguard Worker return workload;
1366*89c4ff92SAndroid Build Coastguard Worker }
1367*89c4ff92SAndroid Build Coastguard Worker
1368*89c4ff92SAndroid Build Coastguard Worker template <typename FullyConnectedWorkload, armnn::DataType DataType>
CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1369*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<FullyConnectedWorkload> CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest
1370*89c4ff92SAndroid Build Coastguard Worker (armnn::IWorkloadFactory& factory,
1371*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1372*89c4ff92SAndroid Build Coastguard Worker {
1373*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1374*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor layerDesc;
1375*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_BiasEnabled = true;
1376*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_TransposeWeightMatrix = true;
1377*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ConstantWeights = false;
1378*89c4ff92SAndroid Build Coastguard Worker
1379*89c4ff92SAndroid Build Coastguard Worker FullyConnectedLayer* const layer = graph.AddLayer<FullyConnectedLayer>(layerDesc, "layer");
1380*89c4ff92SAndroid Build Coastguard Worker
1381*89c4ff92SAndroid Build Coastguard Worker float inputsQScale = 1.0f;
1382*89c4ff92SAndroid Build Coastguard Worker float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0;
1383*89c4ff92SAndroid Build Coastguard Worker
1384*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers with weights and biases as input layers.
1385*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(1, "input");
1386*89c4ff92SAndroid Build Coastguard Worker Layer* const weights = graph.AddLayer<InputLayer>(2, "weights");
1387*89c4ff92SAndroid Build Coastguard Worker Layer* const biases = graph.AddLayer<InputLayer>(3, "biases");
1388*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1389*89c4ff92SAndroid Build Coastguard Worker
1390*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1391*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo({3, 1, 4, 5}, DataType, inputsQScale), 0, 0);
1392*89c4ff92SAndroid Build Coastguard Worker Connect(weights, layer, TensorInfo({7, 20}, DataType, inputsQScale), 0, 1);
1393*89c4ff92SAndroid Build Coastguard Worker Connect(biases, layer, TensorInfo({7}, GetBiasDataType(DataType), inputsQScale), 0, 2);
1394*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo({3, 7}, DataType, outputQScale));
1395*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1396*89c4ff92SAndroid Build Coastguard Worker
1397*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1398*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<FullyConnectedWorkload>(*layer, factory);
1399*89c4ff92SAndroid Build Coastguard Worker
1400*89c4ff92SAndroid Build Coastguard Worker FullyConnectedQueueDescriptor queueDescriptor = workload->GetData();
1401*89c4ff92SAndroid Build Coastguard Worker
1402*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_BiasEnabled == true);
1403*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_TransposeWeightMatrix == true);
1404*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_ConstantWeights == false);
1405*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 3);
1406*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1407*89c4ff92SAndroid Build Coastguard Worker
1408*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1409*89c4ff92SAndroid Build Coastguard Worker return workload;
1410*89c4ff92SAndroid Build Coastguard Worker }
1411*89c4ff92SAndroid Build Coastguard Worker
1412*89c4ff92SAndroid Build Coastguard Worker
1413*89c4ff92SAndroid Build Coastguard Worker template <typename NormalizationWorkload, armnn::DataType DataType>
CreateNormalizationWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW)1414*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<NormalizationWorkload> CreateNormalizationWorkloadTest(armnn::IWorkloadFactory& factory,
1415*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
1416*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW)
1417*89c4ff92SAndroid Build Coastguard Worker {
1418*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1419*89c4ff92SAndroid Build Coastguard Worker NormalizationDescriptor layerDesc;
1420*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_NormChannelType = NormalizationAlgorithmChannel::Across;
1421*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_NormMethodType = NormalizationAlgorithmMethod::LocalBrightness;
1422*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_NormSize = 3;
1423*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_Alpha = 0.5f;
1424*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_Beta = -1.0f;
1425*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_K = 0.2f;
1426*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
1427*89c4ff92SAndroid Build Coastguard Worker
1428*89c4ff92SAndroid Build Coastguard Worker NormalizationLayer* layer = graph.AddLayer<NormalizationLayer>(layerDesc, "layer");
1429*89c4ff92SAndroid Build Coastguard Worker
1430*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1431*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1432*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1433*89c4ff92SAndroid Build Coastguard Worker
1434*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = (dataLayout == DataLayout::NCHW) ?
1435*89c4ff92SAndroid Build Coastguard Worker TensorShape{ 3, 5, 5, 1 } : TensorShape{ 3, 1, 5, 5 };
1436*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = (dataLayout == DataLayout::NCHW) ?
1437*89c4ff92SAndroid Build Coastguard Worker TensorShape{ 3, 5, 5, 1 } : TensorShape{ 3, 1, 5, 5 };
1438*89c4ff92SAndroid Build Coastguard Worker
1439*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1440*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo(inputShape, DataType);
1441*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, DataType);
1442*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
1443*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
1444*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1445*89c4ff92SAndroid Build Coastguard Worker
1446*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1447*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<NormalizationWorkload>(*layer, factory);
1448*89c4ff92SAndroid Build Coastguard Worker
1449*89c4ff92SAndroid Build Coastguard Worker NormalizationQueueDescriptor queueDescriptor = workload->GetData();
1450*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_NormChannelType == NormalizationAlgorithmChannel::Across));
1451*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_NormMethodType == NormalizationAlgorithmMethod::LocalBrightness));
1452*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_NormSize == 3);
1453*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_Alpha == 0.5f);
1454*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_Beta == -1.0f);
1455*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_K == 0.2f);
1456*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
1457*89c4ff92SAndroid Build Coastguard Worker
1458*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1459*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1460*89c4ff92SAndroid Build Coastguard Worker
1461*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1462*89c4ff92SAndroid Build Coastguard Worker return workload;
1463*89c4ff92SAndroid Build Coastguard Worker }
1464*89c4ff92SAndroid Build Coastguard Worker
1465*89c4ff92SAndroid Build Coastguard Worker template <typename Pooling2dWorkload, armnn::DataType DataType>
CreatePooling2dWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW)1466*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Pooling2dWorkload> CreatePooling2dWorkloadTest(armnn::IWorkloadFactory& factory,
1467*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
1468*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW)
1469*89c4ff92SAndroid Build Coastguard Worker {
1470*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1471*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor layerDesc;
1472*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PoolType = PoolingAlgorithm::Average;
1473*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PoolWidth = 3;
1474*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PoolHeight = 3;
1475*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadLeft = 2;
1476*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadRight = 2;
1477*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadTop = 1;
1478*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PadBottom = 1;
1479*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideX = 2;
1480*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_StrideY = 3;
1481*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_OutputShapeRounding = OutputShapeRounding::Floor;
1482*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
1483*89c4ff92SAndroid Build Coastguard Worker
1484*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const layer = graph.AddLayer<Pooling2dLayer>(layerDesc, "layer");
1485*89c4ff92SAndroid Build Coastguard Worker
1486*89c4ff92SAndroid Build Coastguard Worker // Create extra layers
1487*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1488*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1489*89c4ff92SAndroid Build Coastguard Worker
1490*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 5, 5} : TensorShape{3, 5, 5, 2};
1491*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 2, 4} : TensorShape{3, 2, 4, 2};
1492*89c4ff92SAndroid Build Coastguard Worker
1493*89c4ff92SAndroid Build Coastguard Worker // Connect up
1494*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, TensorInfo(inputShape, DataType));
1495*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, TensorInfo(outputShape, DataType));
1496*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1497*89c4ff92SAndroid Build Coastguard Worker
1498*89c4ff92SAndroid Build Coastguard Worker // Make the workload and checks it
1499*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<Pooling2dWorkload>(*layer, factory);
1500*89c4ff92SAndroid Build Coastguard Worker
1501*89c4ff92SAndroid Build Coastguard Worker Pooling2dQueueDescriptor queueDescriptor = workload->GetData();
1502*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_PoolType == PoolingAlgorithm::Average));
1503*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_OutputShapeRounding == OutputShapeRounding::Floor));
1504*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PoolWidth == 3);
1505*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PoolHeight == 3);
1506*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideX == 2);
1507*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_StrideY == 3);
1508*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadLeft == 2);
1509*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadRight == 2);
1510*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadTop == 1);
1511*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_PadBottom == 1);
1512*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
1513*89c4ff92SAndroid Build Coastguard Worker
1514*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1515*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1516*89c4ff92SAndroid Build Coastguard Worker
1517*89c4ff92SAndroid Build Coastguard Worker // Return so we can do extra, backend-specific tests
1518*89c4ff92SAndroid Build Coastguard Worker return workload;
1519*89c4ff92SAndroid Build Coastguard Worker }
1520*89c4ff92SAndroid Build Coastguard Worker
1521*89c4ff92SAndroid Build Coastguard Worker template <typename SoftmaxWorkload, armnn::DataType DataType>
CreateSoftmaxWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1522*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<SoftmaxWorkload> CreateSoftmaxWorkloadTest(armnn::IWorkloadFactory& factory,
1523*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1524*89c4ff92SAndroid Build Coastguard Worker {
1525*89c4ff92SAndroid Build Coastguard Worker // Create the layer we're testing.
1526*89c4ff92SAndroid Build Coastguard Worker SoftmaxDescriptor softmaxDescriptor;
1527*89c4ff92SAndroid Build Coastguard Worker // Set Axis to -1 if CL or Neon until further Axes are supported.
1528*89c4ff92SAndroid Build Coastguard Worker if (factory.GetBackendId() == armnn::Compute::CpuAcc || factory.GetBackendId() == armnn::Compute::GpuAcc)
1529*89c4ff92SAndroid Build Coastguard Worker {
1530*89c4ff92SAndroid Build Coastguard Worker softmaxDescriptor.m_Axis = -1;
1531*89c4ff92SAndroid Build Coastguard Worker }
1532*89c4ff92SAndroid Build Coastguard Worker
1533*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<SoftmaxLayer>(softmaxDescriptor, "layer");
1534*89c4ff92SAndroid Build Coastguard Worker // Create extra layers.
1535*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1536*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1537*89c4ff92SAndroid Build Coastguard Worker
1538*89c4ff92SAndroid Build Coastguard Worker // Connect up
1539*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({4, 1}, DataType);
1540*89c4ff92SAndroid Build Coastguard Worker if (DataType == armnn::DataType::QAsymmU8)
1541*89c4ff92SAndroid Build Coastguard Worker {
1542*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetQuantizationOffset(0);
1543*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetQuantizationScale(1.f / 256);
1544*89c4ff92SAndroid Build Coastguard Worker }
1545*89c4ff92SAndroid Build Coastguard Worker else if (DataType == armnn::DataType::QAsymmS8)
1546*89c4ff92SAndroid Build Coastguard Worker {
1547*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetQuantizationOffset(-128);
1548*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetQuantizationScale(1.f / 256);
1549*89c4ff92SAndroid Build Coastguard Worker }
1550*89c4ff92SAndroid Build Coastguard Worker
1551*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo);
1552*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
1553*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1554*89c4ff92SAndroid Build Coastguard Worker
1555*89c4ff92SAndroid Build Coastguard Worker // Make the workload and checks it.
1556*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<SoftmaxWorkload>(*layer, factory);
1557*89c4ff92SAndroid Build Coastguard Worker
1558*89c4ff92SAndroid Build Coastguard Worker SoftmaxQueueDescriptor queueDescriptor = workload->GetData();
1559*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1560*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1561*89c4ff92SAndroid Build Coastguard Worker
1562*89c4ff92SAndroid Build Coastguard Worker // Return so we can do extra, backend-specific tests.
1563*89c4ff92SAndroid Build Coastguard Worker return workload;
1564*89c4ff92SAndroid Build Coastguard Worker }
1565*89c4ff92SAndroid Build Coastguard Worker
1566*89c4ff92SAndroid Build Coastguard Worker template<typename SplitterWorkload, armnn::DataType DataType>
1567*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<SplitterWorkload>
CreateSplitterWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1568*89c4ff92SAndroid Build Coastguard Worker CreateSplitterWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph)
1569*89c4ff92SAndroid Build Coastguard Worker {
1570*89c4ff92SAndroid Build Coastguard Worker // Create the layer we're testing.
1571*89c4ff92SAndroid Build Coastguard Worker // NOTE: need three dimensions channels, height/y, width/x because the Compute
1572*89c4ff92SAndroid Build Coastguard Worker // library restricts subtensors to have the same x and y dimensions as
1573*89c4ff92SAndroid Build Coastguard Worker // their parent tensors, and therefore the origin on the x and y dimension
1574*89c4ff92SAndroid Build Coastguard Worker // has to be zero for any view. So we need a third dimension to split...
1575*89c4ff92SAndroid Build Coastguard Worker // NOTE: arguments are: number of views, number of dimensions.
1576*89c4ff92SAndroid Build Coastguard Worker ViewsDescriptor layerDesc(3, 3);
1577*89c4ff92SAndroid Build Coastguard Worker // NOTE: arguments are: view, dimension, value.
1578*89c4ff92SAndroid Build Coastguard Worker layerDesc.SetViewOriginCoord(0, 0, 0);
1579*89c4ff92SAndroid Build Coastguard Worker layerDesc.SetViewOriginCoord(1, 0, 1);
1580*89c4ff92SAndroid Build Coastguard Worker layerDesc.SetViewOriginCoord(2, 0, 3);
1581*89c4ff92SAndroid Build Coastguard Worker
1582*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<SplitterLayer>(layerDesc, "layer");
1583*89c4ff92SAndroid Build Coastguard Worker
1584*89c4ff92SAndroid Build Coastguard Worker // Adds extra layers.
1585*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1586*89c4ff92SAndroid Build Coastguard Worker Layer* const output0 = graph.AddLayer<OutputLayer>(0, "output0");
1587*89c4ff92SAndroid Build Coastguard Worker Layer* const output1 = graph.AddLayer<OutputLayer>(1, "output1");
1588*89c4ff92SAndroid Build Coastguard Worker Layer* const output2 = graph.AddLayer<OutputLayer>(2, "output2");
1589*89c4ff92SAndroid Build Coastguard Worker
1590*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1591*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({5, 7, 7}, DataType);
1592*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo);
1593*89c4ff92SAndroid Build Coastguard Worker
1594*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo output0Info({1, 7, 7}, DataType);
1595*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo output1Info({2, 7, 7}, DataType);
1596*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo output2Info({2, 7, 7}, DataType);
1597*89c4ff92SAndroid Build Coastguard Worker
1598*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output0, output0Info, 0, 0);
1599*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output1, output1Info, 1, 0);
1600*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output2, output2Info, 2, 0);
1601*89c4ff92SAndroid Build Coastguard Worker
1602*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1603*89c4ff92SAndroid Build Coastguard Worker
1604*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1605*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<SplitterWorkload>(*layer, factory);
1606*89c4ff92SAndroid Build Coastguard Worker
1607*89c4ff92SAndroid Build Coastguard Worker SplitterQueueDescriptor queueDescriptor = workload->GetData();
1608*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1609*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 3);
1610*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins.size() == 3);
1611*89c4ff92SAndroid Build Coastguard Worker
1612*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[0].m_Origin[0] == 0);
1613*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[1].m_Origin[0] == 1);
1614*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[2].m_Origin[0] == 3);
1615*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[0].m_Origin[1] == 0);
1616*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[1].m_Origin[1] == 0);
1617*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[2].m_Origin[1] == 0);
1618*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[0].m_Origin[2] == 0);
1619*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[1].m_Origin[2] == 0);
1620*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_ViewOrigins[2].m_Origin[2] == 0);
1621*89c4ff92SAndroid Build Coastguard Worker
1622*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1623*89c4ff92SAndroid Build Coastguard Worker return workload;
1624*89c4ff92SAndroid Build Coastguard Worker }
1625*89c4ff92SAndroid Build Coastguard Worker
1626*89c4ff92SAndroid Build Coastguard Worker /// This function constructs a graph with both a splitter and a concat, and returns a pair of the workloads.
1627*89c4ff92SAndroid Build Coastguard Worker template<typename SplitterWorkload, typename ConcatWorkload, armnn::DataType DataType>
1628*89c4ff92SAndroid Build Coastguard Worker std::pair<std::unique_ptr<SplitterWorkload>, std::unique_ptr<ConcatWorkload>>
CreateSplitterConcatWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1629*89c4ff92SAndroid Build Coastguard Worker CreateSplitterConcatWorkloadTest(armnn::IWorkloadFactory &factory, armnn::Graph &graph)
1630*89c4ff92SAndroid Build Coastguard Worker {
1631*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({ 1, 2, 100, 10 }, DataType);
1632*89c4ff92SAndroid Build Coastguard Worker
1633*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo splitTensorInfo1({ 1, 1, 100, 10 }, DataType);
1634*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo splitTensorInfo2({ 1, 1, 100, 10 }, DataType);
1635*89c4ff92SAndroid Build Coastguard Worker
1636*89c4ff92SAndroid Build Coastguard Worker //Constructs the graph.
1637*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1638*89c4ff92SAndroid Build Coastguard Worker
1639*89c4ff92SAndroid Build Coastguard Worker armnn::ViewsDescriptor splitterViews(2);
1640*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 0, 0);
1641*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 1, 0);
1642*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 2, 0);
1643*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 3, 0);
1644*89c4ff92SAndroid Build Coastguard Worker
1645*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 0, 0);
1646*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 1, 1);
1647*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 2, 0);
1648*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 3, 0);
1649*89c4ff92SAndroid Build Coastguard Worker
1650*89c4ff92SAndroid Build Coastguard Worker // create splitter layer
1651*89c4ff92SAndroid Build Coastguard Worker Layer* const splitter = graph.AddLayer<SplitterLayer>(splitterViews, "splitter");
1652*89c4ff92SAndroid Build Coastguard Worker CHECK(splitter);
1653*89c4ff92SAndroid Build Coastguard Worker
1654*89c4ff92SAndroid Build Coastguard Worker armnn::OriginsDescriptor concatViews(2);
1655*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(0, 0, 0);
1656*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(0, 1, 1);
1657*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(0, 2, 0);
1658*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(0, 3, 0);
1659*89c4ff92SAndroid Build Coastguard Worker
1660*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(1, 0, 0);
1661*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(1, 1, 0);
1662*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(1, 2, 0);
1663*89c4ff92SAndroid Build Coastguard Worker concatViews.SetViewOriginCoord(1, 3, 0);
1664*89c4ff92SAndroid Build Coastguard Worker
1665*89c4ff92SAndroid Build Coastguard Worker // create concat layer
1666*89c4ff92SAndroid Build Coastguard Worker Layer* const concat = graph.AddLayer<ConcatLayer>(concatViews, "concat");
1667*89c4ff92SAndroid Build Coastguard Worker CHECK(concat);
1668*89c4ff92SAndroid Build Coastguard Worker
1669*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1670*89c4ff92SAndroid Build Coastguard Worker
1671*89c4ff92SAndroid Build Coastguard Worker // Adds connections.
1672*89c4ff92SAndroid Build Coastguard Worker // connect input to splitter
1673*89c4ff92SAndroid Build Coastguard Worker Connect(input, splitter, inputTensorInfo, 0, 0);
1674*89c4ff92SAndroid Build Coastguard Worker // connect splitter[0] to concat[1]
1675*89c4ff92SAndroid Build Coastguard Worker Connect(splitter, concat, splitTensorInfo1, 0, 1); // The splitter & concat are connected up.
1676*89c4ff92SAndroid Build Coastguard Worker // connect splitter[1] to concat[0]
1677*89c4ff92SAndroid Build Coastguard Worker Connect(splitter, concat, splitTensorInfo2, 1, 0); // So that the outputs are flipped round.
1678*89c4ff92SAndroid Build Coastguard Worker // connect concat to output
1679*89c4ff92SAndroid Build Coastguard Worker Connect(concat, output, inputTensorInfo, 0, 0);
1680*89c4ff92SAndroid Build Coastguard Worker
1681*89c4ff92SAndroid Build Coastguard Worker // created tensor handles
1682*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1683*89c4ff92SAndroid Build Coastguard Worker
1684*89c4ff92SAndroid Build Coastguard Worker // created splitter workload
1685*89c4ff92SAndroid Build Coastguard Worker auto workloadSplitter = MakeAndCheckWorkload<SplitterWorkload>(*splitter, factory);
1686*89c4ff92SAndroid Build Coastguard Worker CHECK(workloadSplitter);
1687*89c4ff92SAndroid Build Coastguard Worker // created concat workload
1688*89c4ff92SAndroid Build Coastguard Worker auto workloadConcat = MakeAndCheckWorkload<ConcatWorkload>(*concat, factory);
1689*89c4ff92SAndroid Build Coastguard Worker CHECK(workloadConcat);
1690*89c4ff92SAndroid Build Coastguard Worker
1691*89c4ff92SAndroid Build Coastguard Worker return {std::move(workloadSplitter), std::move(workloadConcat)};
1692*89c4ff92SAndroid Build Coastguard Worker }
1693*89c4ff92SAndroid Build Coastguard Worker
1694*89c4ff92SAndroid Build Coastguard Worker
1695*89c4ff92SAndroid Build Coastguard Worker /// This function constructs a graph with a splitter with two outputs. Each of the outputs is then
1696*89c4ff92SAndroid Build Coastguard Worker /// connected to two different activation layers
1697*89c4ff92SAndroid Build Coastguard Worker template<typename SplitterWorkload, typename ActivationWorkload, armnn::DataType DataType>
CreateSplitterMultipleInputsOneOutputWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,std::unique_ptr<SplitterWorkload> & wlSplitter,std::unique_ptr<ActivationWorkload> & wlActiv0_0,std::unique_ptr<ActivationWorkload> & wlActiv0_1,std::unique_ptr<ActivationWorkload> & wlActiv1_0,std::unique_ptr<ActivationWorkload> & wlActiv1_1)1698*89c4ff92SAndroid Build Coastguard Worker void CreateSplitterMultipleInputsOneOutputWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph,
1699*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<SplitterWorkload>& wlSplitter,
1700*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ActivationWorkload>& wlActiv0_0,
1701*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ActivationWorkload>& wlActiv0_1,
1702*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ActivationWorkload>& wlActiv1_0,
1703*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ActivationWorkload>& wlActiv1_1)
1704*89c4ff92SAndroid Build Coastguard Worker {
1705*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo ({ 1, 3, 100, 50 }, DataType);
1706*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo splitTensorInfo1({ 1, 1, 100, 50 }, DataType);
1707*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo splitTensorInfo2({ 1, 2, 100, 50 }, DataType);
1708*89c4ff92SAndroid Build Coastguard Worker
1709*89c4ff92SAndroid Build Coastguard Worker //Constructs the graph.
1710*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1711*89c4ff92SAndroid Build Coastguard Worker
1712*89c4ff92SAndroid Build Coastguard Worker armnn::ViewsDescriptor splitterViews(2);
1713*89c4ff92SAndroid Build Coastguard Worker
1714*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 0, 0);
1715*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 1, 0);
1716*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 2, 0);
1717*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(0, 3, 0);
1718*89c4ff92SAndroid Build Coastguard Worker
1719*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 0, 0);
1720*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 1, 1);
1721*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 2, 0);
1722*89c4ff92SAndroid Build Coastguard Worker splitterViews.SetViewOriginCoord(1, 3, 0);
1723*89c4ff92SAndroid Build Coastguard Worker
1724*89c4ff92SAndroid Build Coastguard Worker Layer* const splitter = graph.AddLayer<SplitterLayer>(splitterViews, "splitter");
1725*89c4ff92SAndroid Build Coastguard Worker
1726*89c4ff92SAndroid Build Coastguard Worker armnn::ActivationDescriptor activationDesc;
1727*89c4ff92SAndroid Build Coastguard Worker
1728*89c4ff92SAndroid Build Coastguard Worker Layer* const activ0_0 = graph.AddLayer<ActivationLayer>(activationDesc, "activ0_0");
1729*89c4ff92SAndroid Build Coastguard Worker Layer* const activ0_1 = graph.AddLayer<ActivationLayer>(activationDesc, "activ0_1");
1730*89c4ff92SAndroid Build Coastguard Worker Layer* const activ1_0 = graph.AddLayer<ActivationLayer>(activationDesc, "activ1_0");
1731*89c4ff92SAndroid Build Coastguard Worker Layer* const activ1_1 = graph.AddLayer<ActivationLayer>(activationDesc, "activ1_1");
1732*89c4ff92SAndroid Build Coastguard Worker
1733*89c4ff92SAndroid Build Coastguard Worker Layer* const output1 = graph.AddLayer<OutputLayer>(1, "output1");
1734*89c4ff92SAndroid Build Coastguard Worker Layer* const output2 = graph.AddLayer<OutputLayer>(2, "output2");
1735*89c4ff92SAndroid Build Coastguard Worker Layer* const output3 = graph.AddLayer<OutputLayer>(3, "output3");
1736*89c4ff92SAndroid Build Coastguard Worker Layer* const output4 = graph.AddLayer<OutputLayer>(4, "output4");
1737*89c4ff92SAndroid Build Coastguard Worker
1738*89c4ff92SAndroid Build Coastguard Worker // Adds connections.
1739*89c4ff92SAndroid Build Coastguard Worker Connect(input, splitter, inputTensorInfo, 0, 0);
1740*89c4ff92SAndroid Build Coastguard Worker Connect(splitter, activ0_0, splitTensorInfo1, 0, 0);
1741*89c4ff92SAndroid Build Coastguard Worker Connect(splitter, activ0_1, splitTensorInfo1, 0, 0);
1742*89c4ff92SAndroid Build Coastguard Worker
1743*89c4ff92SAndroid Build Coastguard Worker Connect(splitter, activ1_0, splitTensorInfo2, 1, 0);
1744*89c4ff92SAndroid Build Coastguard Worker Connect(splitter, activ1_1, splitTensorInfo2, 1, 0);
1745*89c4ff92SAndroid Build Coastguard Worker
1746*89c4ff92SAndroid Build Coastguard Worker Connect(activ0_0, output1, splitTensorInfo1, 0, 0);
1747*89c4ff92SAndroid Build Coastguard Worker Connect(activ0_1, output2, splitTensorInfo1, 0, 0);
1748*89c4ff92SAndroid Build Coastguard Worker Connect(activ1_0, output3, splitTensorInfo2, 0, 0);
1749*89c4ff92SAndroid Build Coastguard Worker Connect(activ1_1, output4, splitTensorInfo2, 0, 0);
1750*89c4ff92SAndroid Build Coastguard Worker
1751*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1752*89c4ff92SAndroid Build Coastguard Worker
1753*89c4ff92SAndroid Build Coastguard Worker auto workloadSplitter = MakeAndCheckWorkload<SplitterWorkload>(*splitter, factory);
1754*89c4ff92SAndroid Build Coastguard Worker auto workloadActiv0_0 = MakeAndCheckWorkload<ActivationWorkload>(*activ0_0, factory);
1755*89c4ff92SAndroid Build Coastguard Worker auto workloadActiv0_1 = MakeAndCheckWorkload<ActivationWorkload>(*activ0_1, factory);
1756*89c4ff92SAndroid Build Coastguard Worker auto workloadActiv1_0 = MakeAndCheckWorkload<ActivationWorkload>(*activ1_0, factory);
1757*89c4ff92SAndroid Build Coastguard Worker auto workloadActiv1_1 = MakeAndCheckWorkload<ActivationWorkload>(*activ1_1, factory);
1758*89c4ff92SAndroid Build Coastguard Worker
1759*89c4ff92SAndroid Build Coastguard Worker wlSplitter = std::move(workloadSplitter);
1760*89c4ff92SAndroid Build Coastguard Worker wlActiv0_0 = std::move(workloadActiv0_0);
1761*89c4ff92SAndroid Build Coastguard Worker wlActiv0_1 = std::move(workloadActiv0_1);
1762*89c4ff92SAndroid Build Coastguard Worker wlActiv1_0 = std::move(workloadActiv1_0);
1763*89c4ff92SAndroid Build Coastguard Worker wlActiv1_1 = std::move(workloadActiv1_1);
1764*89c4ff92SAndroid Build Coastguard Worker }
1765*89c4ff92SAndroid Build Coastguard Worker
1766*89c4ff92SAndroid Build Coastguard Worker template <typename ResizeWorkload, armnn::DataType DataType>
CreateResizeBilinearWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW)1767*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ResizeWorkload> CreateResizeBilinearWorkloadTest(armnn::IWorkloadFactory& factory,
1768*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
1769*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW)
1770*89c4ff92SAndroid Build Coastguard Worker {
1771*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape;
1772*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape;
1773*89c4ff92SAndroid Build Coastguard Worker
1774*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout) {
1775*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
1776*89c4ff92SAndroid Build Coastguard Worker inputShape = { 2, 4, 4, 3 };
1777*89c4ff92SAndroid Build Coastguard Worker outputShape = { 2, 2, 2, 3 };
1778*89c4ff92SAndroid Build Coastguard Worker break;
1779*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
1780*89c4ff92SAndroid Build Coastguard Worker default:
1781*89c4ff92SAndroid Build Coastguard Worker inputShape = { 2, 3, 4, 4 };
1782*89c4ff92SAndroid Build Coastguard Worker outputShape = { 2, 3, 2, 2 };
1783*89c4ff92SAndroid Build Coastguard Worker }
1784*89c4ff92SAndroid Build Coastguard Worker
1785*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1786*89c4ff92SAndroid Build Coastguard Worker ResizeDescriptor resizeDesc;
1787*89c4ff92SAndroid Build Coastguard Worker armnnUtils::DataLayoutIndexed dimensionIndices = dataLayout;
1788*89c4ff92SAndroid Build Coastguard Worker resizeDesc.m_Method = ResizeMethod::Bilinear;
1789*89c4ff92SAndroid Build Coastguard Worker resizeDesc.m_TargetWidth = outputShape[dimensionIndices.GetWidthIndex()];
1790*89c4ff92SAndroid Build Coastguard Worker resizeDesc.m_TargetHeight = outputShape[dimensionIndices.GetHeightIndex()];
1791*89c4ff92SAndroid Build Coastguard Worker resizeDesc.m_DataLayout = dataLayout;
1792*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<ResizeLayer>(resizeDesc, "resize");
1793*89c4ff92SAndroid Build Coastguard Worker
1794*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1795*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1796*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1797*89c4ff92SAndroid Build Coastguard Worker
1798*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1799*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo(inputShape, DataType);
1800*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, DataType);
1801*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
1802*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
1803*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1804*89c4ff92SAndroid Build Coastguard Worker
1805*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1806*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<ResizeWorkload>(*layer, factory);
1807*89c4ff92SAndroid Build Coastguard Worker
1808*89c4ff92SAndroid Build Coastguard Worker auto queueDescriptor = workload->GetData();
1809*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1810*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1811*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_DataLayout == dataLayout);
1812*89c4ff92SAndroid Build Coastguard Worker
1813*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1814*89c4ff92SAndroid Build Coastguard Worker return workload;
1815*89c4ff92SAndroid Build Coastguard Worker }
1816*89c4ff92SAndroid Build Coastguard Worker
1817*89c4ff92SAndroid Build Coastguard Worker template <typename BatchToSpaceNdWorkload, armnn::DataType DataType>
CreateBatchToSpaceNdWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1818*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<BatchToSpaceNdWorkload> CreateBatchToSpaceNdWorkloadTest(armnn::IWorkloadFactory& factory,
1819*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1820*89c4ff92SAndroid Build Coastguard Worker {
1821*89c4ff92SAndroid Build Coastguard Worker BatchToSpaceNdDescriptor desc;
1822*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<BatchToSpaceNdLayer>(desc, "batchToSpace");
1823*89c4ff92SAndroid Build Coastguard Worker
1824*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1825*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1826*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1827*89c4ff92SAndroid Build Coastguard Worker
1828*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1829*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({1, 1, 1, 1}, DataType);
1830*89c4ff92SAndroid Build Coastguard Worker
1831*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo);
1832*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
1833*89c4ff92SAndroid Build Coastguard Worker
1834*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1835*89c4ff92SAndroid Build Coastguard Worker
1836*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1837*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<BatchToSpaceNdWorkload>(*layer, factory);
1838*89c4ff92SAndroid Build Coastguard Worker
1839*89c4ff92SAndroid Build Coastguard Worker BatchToSpaceNdQueueDescriptor queueDescriptor = workload->GetData();
1840*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1841*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1842*89c4ff92SAndroid Build Coastguard Worker
1843*89c4ff92SAndroid Build Coastguard Worker return workload;
1844*89c4ff92SAndroid Build Coastguard Worker }
1845*89c4ff92SAndroid Build Coastguard Worker
1846*89c4ff92SAndroid Build Coastguard Worker template <typename LogSoftmaxWorkload, armnn::DataType DataType>
CreateLogSoftmaxWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1847*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<LogSoftmaxWorkload> CreateLogSoftmaxWorkloadTest(armnn::IWorkloadFactory& factory,
1848*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1849*89c4ff92SAndroid Build Coastguard Worker {
1850*89c4ff92SAndroid Build Coastguard Worker // Create the layer we're testing.
1851*89c4ff92SAndroid Build Coastguard Worker LogSoftmaxDescriptor logSoftmaxDescriptor;
1852*89c4ff92SAndroid Build Coastguard Worker // Set Axis to -1 if CL or Neon until further Axes are supported.
1853*89c4ff92SAndroid Build Coastguard Worker if (factory.GetBackendId() == armnn::Compute::CpuAcc || factory.GetBackendId() == armnn::Compute::GpuAcc)
1854*89c4ff92SAndroid Build Coastguard Worker {
1855*89c4ff92SAndroid Build Coastguard Worker logSoftmaxDescriptor.m_Axis = -1;
1856*89c4ff92SAndroid Build Coastguard Worker }
1857*89c4ff92SAndroid Build Coastguard Worker
1858*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<LogSoftmaxLayer>(logSoftmaxDescriptor, "layer");
1859*89c4ff92SAndroid Build Coastguard Worker // Create extra layers.
1860*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1861*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1862*89c4ff92SAndroid Build Coastguard Worker
1863*89c4ff92SAndroid Build Coastguard Worker // Connect up
1864*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo({4, 1}, DataType);
1865*89c4ff92SAndroid Build Coastguard Worker
1866*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, tensorInfo);
1867*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, tensorInfo);
1868*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1869*89c4ff92SAndroid Build Coastguard Worker
1870*89c4ff92SAndroid Build Coastguard Worker // Make the workload and checks it.
1871*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<LogSoftmaxWorkload>(*layer, factory);
1872*89c4ff92SAndroid Build Coastguard Worker
1873*89c4ff92SAndroid Build Coastguard Worker LogSoftmaxQueueDescriptor queueDescriptor = workload->GetData();
1874*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1875*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1876*89c4ff92SAndroid Build Coastguard Worker
1877*89c4ff92SAndroid Build Coastguard Worker // Return so we can do extra, backend-specific tests.
1878*89c4ff92SAndroid Build Coastguard Worker return workload;
1879*89c4ff92SAndroid Build Coastguard Worker }
1880*89c4ff92SAndroid Build Coastguard Worker
1881*89c4ff92SAndroid Build Coastguard Worker template <typename L2NormalizationWorkload, armnn::DataType DataType>
CreateL2NormalizationWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,DataLayout dataLayout=DataLayout::NCHW)1882*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<L2NormalizationWorkload> CreateL2NormalizationWorkloadTest(armnn::IWorkloadFactory& factory,
1883*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
1884*89c4ff92SAndroid Build Coastguard Worker {
1885*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1886*89c4ff92SAndroid Build Coastguard Worker L2NormalizationDescriptor layerDesc;
1887*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_DataLayout = dataLayout;
1888*89c4ff92SAndroid Build Coastguard Worker
1889*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<L2NormalizationLayer>(layerDesc, "l2norm");
1890*89c4ff92SAndroid Build Coastguard Worker
1891*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1892*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1893*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1894*89c4ff92SAndroid Build Coastguard Worker
1895*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = (dataLayout == DataLayout::NCHW) ?
1896*89c4ff92SAndroid Build Coastguard Worker TensorShape{ 5, 20, 50, 67 } : TensorShape{ 5, 50, 67, 20 };
1897*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = (dataLayout == DataLayout::NCHW) ?
1898*89c4ff92SAndroid Build Coastguard Worker TensorShape{ 5, 20, 50, 67 } : TensorShape{ 5, 50, 67, 20 };
1899*89c4ff92SAndroid Build Coastguard Worker
1900*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1901*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo(inputShape, DataType);
1902*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, DataType);
1903*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
1904*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
1905*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1906*89c4ff92SAndroid Build Coastguard Worker
1907*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1908*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<L2NormalizationWorkload>(*layer, factory);
1909*89c4ff92SAndroid Build Coastguard Worker
1910*89c4ff92SAndroid Build Coastguard Worker L2NormalizationQueueDescriptor queueDescriptor = workload->GetData();
1911*89c4ff92SAndroid Build Coastguard Worker CHECK((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
1912*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1913*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1914*89c4ff92SAndroid Build Coastguard Worker
1915*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1916*89c4ff92SAndroid Build Coastguard Worker return workload;
1917*89c4ff92SAndroid Build Coastguard Worker }
1918*89c4ff92SAndroid Build Coastguard Worker
1919*89c4ff92SAndroid Build Coastguard Worker template <typename ReshapeWorkload, armnn::DataType DataType>
CreateReshapeWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1920*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ReshapeWorkload> CreateReshapeWorkloadTest(armnn::IWorkloadFactory& factory,
1921*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
1922*89c4ff92SAndroid Build Coastguard Worker {
1923*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1924*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape({ 1, 4 });
1925*89c4ff92SAndroid Build Coastguard Worker ReshapeDescriptor reshapeDesc;
1926*89c4ff92SAndroid Build Coastguard Worker reshapeDesc.m_TargetShape = outputShape;
1927*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<ReshapeLayer>(reshapeDesc, "layer");
1928*89c4ff92SAndroid Build Coastguard Worker
1929*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1930*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1931*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1932*89c4ff92SAndroid Build Coastguard Worker
1933*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1934*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({ 4, 1 }, DataType);
1935*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, DataType);
1936*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
1937*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
1938*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1939*89c4ff92SAndroid Build Coastguard Worker
1940*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1941*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<ReshapeWorkload>(*layer, factory);
1942*89c4ff92SAndroid Build Coastguard Worker
1943*89c4ff92SAndroid Build Coastguard Worker ReshapeQueueDescriptor queueDescriptor = workload->GetData();
1944*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1945*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1946*89c4ff92SAndroid Build Coastguard Worker
1947*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1948*89c4ff92SAndroid Build Coastguard Worker return workload;
1949*89c4ff92SAndroid Build Coastguard Worker }
1950*89c4ff92SAndroid Build Coastguard Worker
1951*89c4ff92SAndroid Build Coastguard Worker template <typename ConvertFp16ToFp32Float32Workload>
CreateConvertFp16ToFp32WorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1952*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ConvertFp16ToFp32Float32Workload> CreateConvertFp16ToFp32WorkloadTest(
1953*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory, armnn::Graph& graph)
1954*89c4ff92SAndroid Build Coastguard Worker {
1955*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1956*89c4ff92SAndroid Build Coastguard Worker ConvertFp16ToFp32Layer* const layer = graph.AddLayer<ConvertFp16ToFp32Layer>("Fp16ToFp32Converter");
1957*89c4ff92SAndroid Build Coastguard Worker
1958*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1959*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1960*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1961*89c4ff92SAndroid Build Coastguard Worker
1962*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1963*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({1, 3, 2, 3}, armnn::DataType::Float16);
1964*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo({1, 3, 2, 3}, armnn::DataType::Float32);
1965*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
1966*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
1967*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1968*89c4ff92SAndroid Build Coastguard Worker
1969*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1970*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<ConvertFp16ToFp32Float32Workload>(*layer, factory);
1971*89c4ff92SAndroid Build Coastguard Worker
1972*89c4ff92SAndroid Build Coastguard Worker ConvertFp16ToFp32QueueDescriptor queueDescriptor = workload->GetData();
1973*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
1974*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
1975*89c4ff92SAndroid Build Coastguard Worker
1976*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
1977*89c4ff92SAndroid Build Coastguard Worker return workload;
1978*89c4ff92SAndroid Build Coastguard Worker }
1979*89c4ff92SAndroid Build Coastguard Worker
1980*89c4ff92SAndroid Build Coastguard Worker template <typename ConvertFp32ToFp16Float16Workload>
CreateConvertFp32ToFp16WorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)1981*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ConvertFp32ToFp16Float16Workload> CreateConvertFp32ToFp16WorkloadTest(
1982*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory, armnn::Graph& graph)
1983*89c4ff92SAndroid Build Coastguard Worker {
1984*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
1985*89c4ff92SAndroid Build Coastguard Worker ConvertFp32ToFp16Layer* const layer = graph.AddLayer<ConvertFp32ToFp16Layer>("Fp32ToFp16Converter");
1986*89c4ff92SAndroid Build Coastguard Worker
1987*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
1988*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
1989*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
1990*89c4ff92SAndroid Build Coastguard Worker
1991*89c4ff92SAndroid Build Coastguard Worker // Connects up.
1992*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({1, 3, 2, 3}, armnn::DataType::Float32);
1993*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo({1, 3, 2, 3}, armnn::DataType::Float16);
1994*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
1995*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
1996*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
1997*89c4ff92SAndroid Build Coastguard Worker
1998*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
1999*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<ConvertFp32ToFp16Float16Workload>(*layer, factory);
2000*89c4ff92SAndroid Build Coastguard Worker
2001*89c4ff92SAndroid Build Coastguard Worker ConvertFp32ToFp16QueueDescriptor queueDescriptor = workload->GetData();
2002*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
2003*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
2004*89c4ff92SAndroid Build Coastguard Worker
2005*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
2006*89c4ff92SAndroid Build Coastguard Worker return workload;
2007*89c4ff92SAndroid Build Coastguard Worker }
2008*89c4ff92SAndroid Build Coastguard Worker
2009*89c4ff92SAndroid Build Coastguard Worker template <typename MeanWorkload, armnn::DataType DataType>
CreateMeanWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)2010*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MeanWorkload> CreateMeanWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph)
2011*89c4ff92SAndroid Build Coastguard Worker {
2012*89c4ff92SAndroid Build Coastguard Worker // Reduce along the first and second dimensions, and do not keep the reduced dimensions.
2013*89c4ff92SAndroid Build Coastguard Worker MeanDescriptor descriptor({ 1, 2 }, false);
2014*89c4ff92SAndroid Build Coastguard Worker
2015*89c4ff92SAndroid Build Coastguard Worker // Creates the layer we're testing.
2016*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<MeanLayer>(descriptor, "mean");
2017*89c4ff92SAndroid Build Coastguard Worker
2018*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
2019*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
2020*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
2021*89c4ff92SAndroid Build Coastguard Worker
2022*89c4ff92SAndroid Build Coastguard Worker // Connects up.
2023*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({ 1, 3, 7, 4 }, DataType);
2024*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo({ 1, 4 }, DataType);
2025*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
2026*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
2027*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
2028*89c4ff92SAndroid Build Coastguard Worker
2029*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
2030*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<MeanWorkload>(*layer, factory);
2031*89c4ff92SAndroid Build Coastguard Worker
2032*89c4ff92SAndroid Build Coastguard Worker MeanQueueDescriptor queueDescriptor = workload->GetData();
2033*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_Axis == descriptor.m_Axis);
2034*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Parameters.m_KeepDims == descriptor.m_KeepDims);
2035*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
2036*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
2037*89c4ff92SAndroid Build Coastguard Worker
2038*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
2039*89c4ff92SAndroid Build Coastguard Worker return workload;
2040*89c4ff92SAndroid Build Coastguard Worker }
2041*89c4ff92SAndroid Build Coastguard Worker
2042*89c4ff92SAndroid Build Coastguard Worker template<typename ConcatWorkload, armnn::DataType DataType>
CreateConcatWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,const armnn::TensorShape & outputShape,unsigned int concatAxis)2043*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ConcatWorkload> CreateConcatWorkloadTest(armnn::IWorkloadFactory &factory,
2044*89c4ff92SAndroid Build Coastguard Worker armnn::Graph &graph,
2045*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape &outputShape,
2046*89c4ff92SAndroid Build Coastguard Worker unsigned int concatAxis)
2047*89c4ff92SAndroid Build Coastguard Worker {
2048*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({ 2, 3, 2, 5 }, DataType);
2049*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, DataType);
2050*89c4ff92SAndroid Build Coastguard Worker
2051*89c4ff92SAndroid Build Coastguard Worker // Constructs the graph.
2052*89c4ff92SAndroid Build Coastguard Worker Layer* const input0 = graph.AddLayer<InputLayer>(0, "input0");
2053*89c4ff92SAndroid Build Coastguard Worker Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
2054*89c4ff92SAndroid Build Coastguard Worker armnn::OriginsDescriptor descriptor;
2055*89c4ff92SAndroid Build Coastguard Worker
2056*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorShape> inputShapes{{ 2, 3, 2, 5 }, { 2, 3, 2, 5 }};
2057*89c4ff92SAndroid Build Coastguard Worker
2058*89c4ff92SAndroid Build Coastguard Worker descriptor = CreateDescriptorForConcatenation(inputShapes.begin(),
2059*89c4ff92SAndroid Build Coastguard Worker inputShapes.end(),
2060*89c4ff92SAndroid Build Coastguard Worker concatAxis);
2061*89c4ff92SAndroid Build Coastguard Worker
2062*89c4ff92SAndroid Build Coastguard Worker // create concat layer
2063*89c4ff92SAndroid Build Coastguard Worker Layer* const concat = graph.AddLayer<ConcatLayer>(descriptor, "concat");
2064*89c4ff92SAndroid Build Coastguard Worker CHECK(concat);
2065*89c4ff92SAndroid Build Coastguard Worker
2066*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
2067*89c4ff92SAndroid Build Coastguard Worker
2068*89c4ff92SAndroid Build Coastguard Worker // Adds connections.
2069*89c4ff92SAndroid Build Coastguard Worker // connect input0 to concat
2070*89c4ff92SAndroid Build Coastguard Worker Connect(input0, concat, inputTensorInfo, 0, 0);
2071*89c4ff92SAndroid Build Coastguard Worker // connect input1 to concat
2072*89c4ff92SAndroid Build Coastguard Worker Connect(input1, concat, inputTensorInfo, 0, 1);
2073*89c4ff92SAndroid Build Coastguard Worker // connect concat to output
2074*89c4ff92SAndroid Build Coastguard Worker Connect(concat, output, outputTensorInfo, 0, 0);
2075*89c4ff92SAndroid Build Coastguard Worker
2076*89c4ff92SAndroid Build Coastguard Worker // create tensor handles
2077*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
2078*89c4ff92SAndroid Build Coastguard Worker
2079*89c4ff92SAndroid Build Coastguard Worker // create concat workload
2080*89c4ff92SAndroid Build Coastguard Worker auto workloadConcat = MakeAndCheckWorkload<ConcatWorkload>(*concat, factory);
2081*89c4ff92SAndroid Build Coastguard Worker CHECK(workloadConcat);
2082*89c4ff92SAndroid Build Coastguard Worker
2083*89c4ff92SAndroid Build Coastguard Worker return workloadConcat;
2084*89c4ff92SAndroid Build Coastguard Worker }
2085*89c4ff92SAndroid Build Coastguard Worker
2086*89c4ff92SAndroid Build Coastguard Worker template <typename PreCompiledWorkload, armnn::DataType dataType>
CreatePreCompiledWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,bool biasEnabled=false)2087*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::IOptimizedNetworkPtr, std::unique_ptr<PreCompiledWorkload>> CreatePreCompiledWorkloadTest(
2088*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& factory,
2089*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
2090*89c4ff92SAndroid Build Coastguard Worker bool biasEnabled = false)
2091*89c4ff92SAndroid Build Coastguard Worker {
2092*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(graph);
2093*89c4ff92SAndroid Build Coastguard Worker
2094*89c4ff92SAndroid Build Coastguard Worker // build up the structure of the network
2095*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr net(armnn::INetwork::Create());
2096*89c4ff92SAndroid Build Coastguard Worker
2097*89c4ff92SAndroid Build Coastguard Worker // Add an input layer
2098*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const inputLayer = net->AddInputLayer(0, "input layer");
2099*89c4ff92SAndroid Build Coastguard Worker CHECK(inputLayer);
2100*89c4ff92SAndroid Build Coastguard Worker
2101*89c4ff92SAndroid Build Coastguard Worker // ArmNN weights tensor shape is OIHW (out channels, in channels, height, width) for NCHW
2102*89c4ff92SAndroid Build Coastguard Worker // ArmNN weights tensor shape is OHWI (out channels, height, width, in channels) for NHWC
2103*89c4ff92SAndroid Build Coastguard Worker // this test is using NHWC, so the weights shape is OHWI
2104*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsTensorInfo(TensorShape({16, 1, 1, 16}), dataType, 0.9f, 0, true);
2105*89c4ff92SAndroid Build Coastguard Worker unsigned int weightsLength = weightsTensorInfo.GetNumElements();
2106*89c4ff92SAndroid Build Coastguard Worker
2107*89c4ff92SAndroid Build Coastguard Worker using WeightType = armnn::ResolveType<dataType>;
2108*89c4ff92SAndroid Build Coastguard Worker std::vector<WeightType> convWeightsData(weightsLength);
2109*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < weightsLength; ++i)
2110*89c4ff92SAndroid Build Coastguard Worker {
2111*89c4ff92SAndroid Build Coastguard Worker convWeightsData[i] = static_cast<WeightType>(i);
2112*89c4ff92SAndroid Build Coastguard Worker }
2113*89c4ff92SAndroid Build Coastguard Worker
2114*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weights(weightsTensorInfo, convWeightsData);
2115*89c4ff92SAndroid Build Coastguard Worker
2116*89c4ff92SAndroid Build Coastguard Worker // Add a layer that can be used in the PreCompiled layer
2117*89c4ff92SAndroid Build Coastguard Worker armnn::Convolution2dDescriptor convDesc2d;
2118*89c4ff92SAndroid Build Coastguard Worker convDesc2d.m_StrideX = 1;
2119*89c4ff92SAndroid Build Coastguard Worker convDesc2d.m_StrideY = 1;
2120*89c4ff92SAndroid Build Coastguard Worker convDesc2d.m_BiasEnabled = biasEnabled;
2121*89c4ff92SAndroid Build Coastguard Worker convDesc2d.m_DataLayout = armnn::DataLayout::NHWC;
2122*89c4ff92SAndroid Build Coastguard Worker
2123*89c4ff92SAndroid Build Coastguard Worker
2124*89c4ff92SAndroid Build Coastguard Worker const std::string convLayerName("conv layer");
2125*89c4ff92SAndroid Build Coastguard Worker
2126*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* convLayer = net->AddConvolution2dLayer(convDesc2d, convLayerName.c_str());
2127*89c4ff92SAndroid Build Coastguard Worker
2128*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* weightsLayer = net->AddConstantLayer(weights);
2129*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weights.GetInfo());
2130*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1u));
2131*89c4ff92SAndroid Build Coastguard Worker
2132*89c4ff92SAndroid Build Coastguard Worker if (biasEnabled)
2133*89c4ff92SAndroid Build Coastguard Worker {
2134*89c4ff92SAndroid Build Coastguard Worker constexpr armnn::DataType biasDataType = ( dataType == armnn::DataType::QAsymmU8) ?
2135*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::Signed32 : armnn::DataType::Float32;
2136*89c4ff92SAndroid Build Coastguard Worker
2137*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasTensorInfo(TensorShape({16}), biasDataType, 0.9f * 0.9f, 0, true);
2138*89c4ff92SAndroid Build Coastguard Worker unsigned int biasLength = biasTensorInfo.GetNumElements();
2139*89c4ff92SAndroid Build Coastguard Worker
2140*89c4ff92SAndroid Build Coastguard Worker using BiasType = armnn::ResolveType<biasDataType>;
2141*89c4ff92SAndroid Build Coastguard Worker std::vector<BiasType> biasData(biasLength);
2142*89c4ff92SAndroid Build Coastguard Worker std::fill(biasData.begin(), biasData.end(), static_cast<BiasType>(0));
2143*89c4ff92SAndroid Build Coastguard Worker
2144*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor biases(biasTensorInfo, biasData);
2145*89c4ff92SAndroid Build Coastguard Worker
2146*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* biasLayer = net->AddConstantLayer(biases);
2147*89c4ff92SAndroid Build Coastguard Worker
2148*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biases.GetInfo());
2149*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2u));
2150*89c4ff92SAndroid Build Coastguard Worker }
2151*89c4ff92SAndroid Build Coastguard Worker
2152*89c4ff92SAndroid Build Coastguard Worker CHECK(convLayer);
2153*89c4ff92SAndroid Build Coastguard Worker
2154*89c4ff92SAndroid Build Coastguard Worker // Add an output layer
2155*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const outputLayer = net->AddOutputLayer(0, "output layer");
2156*89c4ff92SAndroid Build Coastguard Worker CHECK(outputLayer);
2157*89c4ff92SAndroid Build Coastguard Worker
2158*89c4ff92SAndroid Build Coastguard Worker // set the tensors in the network (NHWC format)
2159*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(TensorShape({ 1, 16, 16, 16 }), dataType);
2160*89c4ff92SAndroid Build Coastguard Worker if (dataType == armnn::DataType::QAsymmU8)
2161*89c4ff92SAndroid Build Coastguard Worker {
2162*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetQuantizationOffset(0);
2163*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetQuantizationScale(0.9f);
2164*89c4ff92SAndroid Build Coastguard Worker }
2165*89c4ff92SAndroid Build Coastguard Worker
2166*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(TensorShape({1, 16, 16, 16}), dataType);
2167*89c4ff92SAndroid Build Coastguard Worker if (dataType == armnn::DataType::QAsymmU8)
2168*89c4ff92SAndroid Build Coastguard Worker {
2169*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.SetQuantizationOffset(0);
2170*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.SetQuantizationScale(0.9f);
2171*89c4ff92SAndroid Build Coastguard Worker }
2172*89c4ff92SAndroid Build Coastguard Worker
2173*89c4ff92SAndroid Build Coastguard Worker // Connect the layers
2174*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
2175*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
2176*89c4ff92SAndroid Build Coastguard Worker
2177*89c4ff92SAndroid Build Coastguard Worker convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2178*89c4ff92SAndroid Build Coastguard Worker convLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2179*89c4ff92SAndroid Build Coastguard Worker
2180*89c4ff92SAndroid Build Coastguard Worker // Optimize the network for the backend supported by the factory
2181*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {factory.GetBackendId()};
2182*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime::CreationOptions options;
2183*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
2184*89c4ff92SAndroid Build Coastguard Worker armnn::OptimizerOptionsOpaque optimizerOptions;
2185*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec(),
2186*89c4ff92SAndroid Build Coastguard Worker optimizerOptions);
2187*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizedNet != nullptr);
2188*89c4ff92SAndroid Build Coastguard Worker
2189*89c4ff92SAndroid Build Coastguard Worker // Find the PreCompiled layer in the optimised graph
2190*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& optimisedGraph = GetGraphForTesting(optimizedNet.get());
2191*89c4ff92SAndroid Build Coastguard Worker Layer* preCompiledLayer = nullptr;
2192*89c4ff92SAndroid Build Coastguard Worker for (auto& layer : optimisedGraph)
2193*89c4ff92SAndroid Build Coastguard Worker {
2194*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() == LayerType::PreCompiled)
2195*89c4ff92SAndroid Build Coastguard Worker {
2196*89c4ff92SAndroid Build Coastguard Worker preCompiledLayer = layer;
2197*89c4ff92SAndroid Build Coastguard Worker }
2198*89c4ff92SAndroid Build Coastguard Worker }
2199*89c4ff92SAndroid Build Coastguard Worker CHECK(preCompiledLayer != nullptr);
2200*89c4ff92SAndroid Build Coastguard Worker
2201*89c4ff92SAndroid Build Coastguard Worker // Create the TensorHandles.
2202*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(optimisedGraph, factory);
2203*89c4ff92SAndroid Build Coastguard Worker
2204*89c4ff92SAndroid Build Coastguard Worker // Make the workload and check it.
2205*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<PreCompiledWorkload>(*preCompiledLayer, factory);
2206*89c4ff92SAndroid Build Coastguard Worker
2207*89c4ff92SAndroid Build Coastguard Worker PreCompiledQueueDescriptor queueDescriptor = workload->GetData();
2208*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
2209*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
2210*89c4ff92SAndroid Build Coastguard Worker
2211*89c4ff92SAndroid Build Coastguard Worker // Returns the workload so we can do extra, backend-specific tests.
2212*89c4ff92SAndroid Build Coastguard Worker // NOTE: We need to return the optimised network as well, otherwise it gets
2213*89c4ff92SAndroid Build Coastguard Worker // out of scope and the tensor handles get destructed
2214*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(std::move(optimizedNet), std::move(workload));
2215*89c4ff92SAndroid Build Coastguard Worker }
2216*89c4ff92SAndroid Build Coastguard Worker
2217*89c4ff92SAndroid Build Coastguard Worker template<typename ConstantWorkload, armnn::DataType DataType>
CreateConstantWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,const armnn::TensorShape & outputShape)2218*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ConstantWorkload> CreateConstantWorkloadTest(armnn::IWorkloadFactory& factory,
2219*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
2220*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& outputShape)
2221*89c4ff92SAndroid Build Coastguard Worker {
2222*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, DataType);
2223*89c4ff92SAndroid Build Coastguard Worker
2224*89c4ff92SAndroid Build Coastguard Worker // create constant layer
2225*89c4ff92SAndroid Build Coastguard Worker auto constant = graph.AddLayer<ConstantLayer>("constant");
2226*89c4ff92SAndroid Build Coastguard Worker CHECK(constant);
2227*89c4ff92SAndroid Build Coastguard Worker constant->m_LayerOutput = std::make_unique<ScopedTensorHandle>(outputTensorInfo);
2228*89c4ff92SAndroid Build Coastguard Worker
2229*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
2230*89c4ff92SAndroid Build Coastguard Worker
2231*89c4ff92SAndroid Build Coastguard Worker // Adds connections.
2232*89c4ff92SAndroid Build Coastguard Worker // connect constant to output
2233*89c4ff92SAndroid Build Coastguard Worker Connect(constant, output, outputTensorInfo, 0, 0);
2234*89c4ff92SAndroid Build Coastguard Worker
2235*89c4ff92SAndroid Build Coastguard Worker // create tensor handles
2236*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
2237*89c4ff92SAndroid Build Coastguard Worker
2238*89c4ff92SAndroid Build Coastguard Worker // create Constant workload"
2239*89c4ff92SAndroid Build Coastguard Worker auto workloadConstant = MakeAndCheckWorkload<ConstantWorkload>(*constant, factory);
2240*89c4ff92SAndroid Build Coastguard Worker CHECK(workloadConstant);
2241*89c4ff92SAndroid Build Coastguard Worker
2242*89c4ff92SAndroid Build Coastguard Worker return workloadConstant;
2243*89c4ff92SAndroid Build Coastguard Worker }
2244*89c4ff92SAndroid Build Coastguard Worker
2245*89c4ff92SAndroid Build Coastguard Worker template <typename PreluWorkload>
CreatePreluWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,const armnn::TensorShape & inputShape,const armnn::TensorShape & alphaShape,const armnn::TensorShape & outputShape,armnn::DataType dataType)2246*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<PreluWorkload> CreatePreluWorkloadTest(armnn::IWorkloadFactory& factory,
2247*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
2248*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& inputShape,
2249*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& alphaShape,
2250*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& outputShape,
2251*89c4ff92SAndroid Build Coastguard Worker armnn::DataType dataType)
2252*89c4ff92SAndroid Build Coastguard Worker {
2253*89c4ff92SAndroid Build Coastguard Worker // Creates the PReLU layer
2254*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<PreluLayer>("prelu");
2255*89c4ff92SAndroid Build Coastguard Worker CHECK(layer != nullptr);
2256*89c4ff92SAndroid Build Coastguard Worker
2257*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers
2258*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer> (0, "input");
2259*89c4ff92SAndroid Build Coastguard Worker Layer* const alpha = graph.AddLayer<InputLayer> (1, "alpha");
2260*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
2261*89c4ff92SAndroid Build Coastguard Worker CHECK(input != nullptr);
2262*89c4ff92SAndroid Build Coastguard Worker CHECK(alpha != nullptr);
2263*89c4ff92SAndroid Build Coastguard Worker CHECK(output != nullptr);
2264*89c4ff92SAndroid Build Coastguard Worker
2265*89c4ff92SAndroid Build Coastguard Worker // Connects up
2266*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo (inputShape, dataType);
2267*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo alphaTensorInfo (alphaShape, dataType);
2268*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, dataType);
2269*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo, 0, 0);
2270*89c4ff92SAndroid Build Coastguard Worker Connect(alpha, layer, alphaTensorInfo, 0, 1);
2271*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo, 0, 0);
2272*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
2273*89c4ff92SAndroid Build Coastguard Worker
2274*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it
2275*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<PreluWorkload>(*layer, factory);
2276*89c4ff92SAndroid Build Coastguard Worker
2277*89c4ff92SAndroid Build Coastguard Worker PreluQueueDescriptor queueDescriptor = workload->GetData();
2278*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 2);
2279*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
2280*89c4ff92SAndroid Build Coastguard Worker
2281*89c4ff92SAndroid Build Coastguard Worker // Returns so we can do extra, backend-specific tests.
2282*89c4ff92SAndroid Build Coastguard Worker return workload;
2283*89c4ff92SAndroid Build Coastguard Worker }
2284*89c4ff92SAndroid Build Coastguard Worker
2285*89c4ff92SAndroid Build Coastguard Worker template <typename SpaceToDepthWorkload, armnn::DataType DataType>
CreateSpaceToDepthWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph)2286*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<SpaceToDepthWorkload> CreateSpaceToDepthWorkloadTest(armnn::IWorkloadFactory& factory,
2287*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph)
2288*89c4ff92SAndroid Build Coastguard Worker {
2289*89c4ff92SAndroid Build Coastguard Worker SpaceToDepthDescriptor desc;
2290*89c4ff92SAndroid Build Coastguard Worker desc.m_BlockSize = 2;
2291*89c4ff92SAndroid Build Coastguard Worker Layer* const layer = graph.AddLayer<SpaceToDepthLayer>(desc, "spaceToDepth");
2292*89c4ff92SAndroid Build Coastguard Worker
2293*89c4ff92SAndroid Build Coastguard Worker // Creates extra layers.
2294*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
2295*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
2296*89c4ff92SAndroid Build Coastguard Worker
2297*89c4ff92SAndroid Build Coastguard Worker // Connects up.
2298*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({ 1, 2, 2, 1 }, DataType);
2299*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo({ 1, 1, 1, 4 }, DataType);
2300*89c4ff92SAndroid Build Coastguard Worker
2301*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, inputTensorInfo);
2302*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, outputTensorInfo);
2303*89c4ff92SAndroid Build Coastguard Worker
2304*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
2305*89c4ff92SAndroid Build Coastguard Worker
2306*89c4ff92SAndroid Build Coastguard Worker // Makes the workload and checks it.
2307*89c4ff92SAndroid Build Coastguard Worker auto workload = MakeAndCheckWorkload<SpaceToDepthWorkload>(*layer, factory);
2308*89c4ff92SAndroid Build Coastguard Worker
2309*89c4ff92SAndroid Build Coastguard Worker SpaceToDepthQueueDescriptor queueDescriptor = workload->GetData();
2310*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == 1);
2311*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
2312*89c4ff92SAndroid Build Coastguard Worker
2313*89c4ff92SAndroid Build Coastguard Worker return workload;
2314*89c4ff92SAndroid Build Coastguard Worker }
2315*89c4ff92SAndroid Build Coastguard Worker
2316*89c4ff92SAndroid Build Coastguard Worker template <typename StackWorkload, armnn::DataType DataType>
CreateStackWorkloadTest(armnn::IWorkloadFactory & factory,armnn::Graph & graph,const armnn::TensorShape & inputShape,const armnn::TensorShape & outputShape,unsigned int axis,unsigned int numInputs)2317*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<StackWorkload> CreateStackWorkloadTest(armnn::IWorkloadFactory& factory,
2318*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph,
2319*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& inputShape,
2320*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& outputShape,
2321*89c4ff92SAndroid Build Coastguard Worker unsigned int axis,
2322*89c4ff92SAndroid Build Coastguard Worker unsigned int numInputs)
2323*89c4ff92SAndroid Build Coastguard Worker {
2324*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo(inputShape, DataType);
2325*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo(outputShape, DataType);
2326*89c4ff92SAndroid Build Coastguard Worker
2327*89c4ff92SAndroid Build Coastguard Worker // Constructs the Stack layer.
2328*89c4ff92SAndroid Build Coastguard Worker armnn::StackDescriptor descriptor(axis, numInputs, inputShape);
2329*89c4ff92SAndroid Build Coastguard Worker Layer* const stackLayer = graph.AddLayer<StackLayer>(descriptor, "stack");
2330*89c4ff92SAndroid Build Coastguard Worker CHECK(stackLayer != nullptr);
2331*89c4ff92SAndroid Build Coastguard Worker
2332*89c4ff92SAndroid Build Coastguard Worker // Constructs layer inputs and output.
2333*89c4ff92SAndroid Build Coastguard Worker std::vector<Layer*> inputs;
2334*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i<numInputs; ++i)
2335*89c4ff92SAndroid Build Coastguard Worker {
2336*89c4ff92SAndroid Build Coastguard Worker inputs.push_back(graph.AddLayer<InputLayer>(
2337*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(i),
2338*89c4ff92SAndroid Build Coastguard Worker ("input" + std::to_string(i)).c_str()
2339*89c4ff92SAndroid Build Coastguard Worker ));
2340*89c4ff92SAndroid Build Coastguard Worker CHECK(inputs[i] != nullptr);
2341*89c4ff92SAndroid Build Coastguard Worker }
2342*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
2343*89c4ff92SAndroid Build Coastguard Worker CHECK(output != nullptr);
2344*89c4ff92SAndroid Build Coastguard Worker
2345*89c4ff92SAndroid Build Coastguard Worker // Adds connections.
2346*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i<numInputs; ++i)
2347*89c4ff92SAndroid Build Coastguard Worker {
2348*89c4ff92SAndroid Build Coastguard Worker Connect(inputs[i], stackLayer, inputTensorInfo, 0, i);
2349*89c4ff92SAndroid Build Coastguard Worker }
2350*89c4ff92SAndroid Build Coastguard Worker Connect(stackLayer, output, outputTensorInfo, 0, 0);
2351*89c4ff92SAndroid Build Coastguard Worker
2352*89c4ff92SAndroid Build Coastguard Worker CreateTensorHandles(graph, factory);
2353*89c4ff92SAndroid Build Coastguard Worker
2354*89c4ff92SAndroid Build Coastguard Worker auto stackWorkload = MakeAndCheckWorkload<StackWorkload>(*stackLayer, factory);
2355*89c4ff92SAndroid Build Coastguard Worker StackQueueDescriptor queueDescriptor = stackWorkload->GetData();
2356*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Inputs.size() == numInputs);
2357*89c4ff92SAndroid Build Coastguard Worker CHECK(queueDescriptor.m_Outputs.size() == 1);
2358*89c4ff92SAndroid Build Coastguard Worker
2359*89c4ff92SAndroid Build Coastguard Worker return stackWorkload;
2360*89c4ff92SAndroid Build Coastguard Worker }
2361*89c4ff92SAndroid Build Coastguard Worker
2362*89c4ff92SAndroid Build Coastguard Worker } // Anonymous namespace
2363