1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
#include <CommonTestUtils.hpp>

#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>

#include <armnn/utility/NumericCast.hpp>

#include <doctest/doctest.h>

#include <map>
#include <utility>
#include <vector>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker namespace
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker
/// Builds a FullyConnected network whose weights arrive at run time as a second network input:
///   Input(0)         -> FullyConnected slot 0
///   Weights_Input(1) -> FullyConnected slot 1
///   FullyConnected   -> Output(0)
/// No bias connection is made.
/// @param inputTensorInfo   TensorInfo given to Connect for Input -> FullyConnected.
/// @param outputTensorInfo  TensorInfo given to Connect for FullyConnected -> Output.
/// @param weightsTensorInfo TensorInfo given to Connect for Weights_Input -> FullyConnected.
/// @param descriptor        Passed unchanged to AddFullyConnectedLayer.
/// @return The constructed network.
armnn::INetworkPtr CreateFullyConnectedNetworkNonConstWeights(const armnn::TensorInfo& inputTensorInfo,
                                                              const armnn::TensorInfo& outputTensorInfo,
                                                              const armnn::TensorInfo& weightsTensorInfo,
                                                              armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net = armnn::INetwork::Create();

    armnn::IConnectableLayer* const dataIn    = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const weightsIn = net->AddInputLayer(1, "Weights_Input");
    armnn::IConnectableLayer* const fc        = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const dataOut   = net->AddOutputLayer(0, "Output");

    Connect(dataIn,    fc,      inputTensorInfo,   0, 0);
    Connect(weightsIn, fc,      weightsTensorInfo, 0, 1);
    Connect(fc,        dataOut, outputTensorInfo,  0, 0);

    return net;
}
40*89c4ff92SAndroid Build Coastguard Worker
/// Builds a FullyConnected network with run-time weights and a constant bias:
///   Input(0)         -> FullyConnected slot 0
///   Weights_Input(1) -> FullyConnected slot 1
///   ConstantLayer    -> FullyConnected slot 2 (holds biasConstantTensor)
///   FullyConnected   -> Output(0)
/// @param inputTensorInfo    TensorInfo given to Connect for Input -> FullyConnected.
/// @param outputTensorInfo   TensorInfo given to Connect for FullyConnected -> Output.
/// @param weightsTensorInfo  TensorInfo given to Connect for Weights_Input -> FullyConnected.
/// @param biasTensorInfo     TensorInfo given to Connect for the bias ConstantLayer -> FullyConnected.
/// @param biasConstantTensor Data for the bias ConstantLayer.
/// @param descriptor         Passed unchanged to AddFullyConnectedLayer.
/// @return The constructed network.
armnn::INetworkPtr CreateFullyConnectedNetworkNonConstWeightsConstBias(const armnn::TensorInfo& inputTensorInfo,
                                                                       const armnn::TensorInfo& outputTensorInfo,
                                                                       const armnn::TensorInfo& weightsTensorInfo,
                                                                       const armnn::TensorInfo& biasTensorInfo,
                                                                       const armnn::ConstTensor& biasConstantTensor,
                                                                       armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr network(armnn::INetwork::Create());

    armnn::IConnectableLayer* inputLayer = network->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* weightsInputLayer = network->AddInputLayer(1, "Weights_Input");
    // This ConstantLayer carries the bias tensor, so it is named "Bias" (not "Weights").
    armnn::IConnectableLayer* biasLayer = network->AddConstantLayer(biasConstantTensor, "Bias");
    armnn::IConnectableLayer* fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* outputLayer = network->AddOutputLayer(0, "Output");

    Connect(inputLayer, fullyConnectedLayer, inputTensorInfo, 0, 0);
    Connect(weightsInputLayer, fullyConnectedLayer, weightsTensorInfo, 0, 1);
    Connect(biasLayer, fullyConnectedLayer, biasTensorInfo, 0, 2);
    Connect(fullyConnectedLayer, outputLayer, outputTensorInfo, 0, 0);

    return network;
}
63*89c4ff92SAndroid Build Coastguard Worker
/// Builds a FullyConnected network with constant weights and a run-time bias:
///   Input(0)       -> FullyConnected slot 0
///   ConstantLayer  -> FullyConnected slot 1 (holds weightsConstantTensor)
///   Bias_Input(2)  -> FullyConnected slot 2
///   FullyConnected -> Output(0)
/// The bias input uses binding id 2; callers feed bias data under key 2.
/// @param inputTensorInfo       TensorInfo given to Connect for Input -> FullyConnected.
/// @param outputTensorInfo      TensorInfo given to Connect for FullyConnected -> Output.
/// @param weightsTensorInfo     TensorInfo given to Connect for the weights ConstantLayer -> FullyConnected.
/// @param biasTensorInfo        TensorInfo given to Connect for Bias_Input -> FullyConnected.
/// @param weightsConstantTensor Data for the weights ConstantLayer.
/// @param descriptor            Passed unchanged to AddFullyConnectedLayer.
/// @return The constructed network.
armnn::INetworkPtr CreateFullyConnectedNetworkConstWeightsNonConstBias(const armnn::TensorInfo& inputTensorInfo,
                                                                       const armnn::TensorInfo& outputTensorInfo,
                                                                       const armnn::TensorInfo& weightsTensorInfo,
                                                                       const armnn::TensorInfo& biasTensorInfo,
                                                                       const armnn::ConstTensor& weightsConstantTensor,
                                                                       armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net = armnn::INetwork::Create();

    armnn::IConnectableLayer* const dataIn  = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const weights = net->AddConstantLayer(weightsConstantTensor, "Weights");
    armnn::IConnectableLayer* const biasIn  = net->AddInputLayer(2, "Bias_Input");
    armnn::IConnectableLayer* const fc      = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const dataOut = net->AddOutputLayer(0, "Output");

    Connect(dataIn,  fc,      inputTensorInfo,   0, 0);
    Connect(weights, fc,      weightsTensorInfo, 0, 1);
    Connect(biasIn,  fc,      biasTensorInfo,    0, 2);
    Connect(fc,      dataOut, outputTensorInfo,  0, 0);

    return net;
}
86*89c4ff92SAndroid Build Coastguard Worker
/// Builds Input(0) -> FullyConnected -> Output(0) with a constant weights layer on FullyConnected slot 1.
/// Unlike the other builders, the weights connection is made directly through IOutputSlot::Connect
/// rather than the Connect helper, so no TensorInfo is passed for it (hence "NoTensorInfo" in the name).
/// @param inputTensorInfo       TensorInfo given to Connect for Input -> FullyConnected.
/// @param outputTensorInfo      TensorInfo given to Connect for FullyConnected -> Output.
/// @param weightsConstantTensor Data for the weights ConstantLayer.
/// @param descriptor            Passed unchanged to AddFullyConnectedLayer.
/// @return The constructed network.
armnn::INetworkPtr CreateFullyConnectedNetworkNoTensorInfoConstWeights(const armnn::TensorInfo& inputTensorInfo,
                                                                       const armnn::TensorInfo& outputTensorInfo,
                                                                       const armnn::ConstTensor& weightsConstantTensor,
                                                                       armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net = armnn::INetwork::Create();

    armnn::IConnectableLayer* const dataIn  = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const weights = net->AddConstantLayer(weightsConstantTensor, "Weights");
    armnn::IConnectableLayer* const fc      = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const dataOut = net->AddOutputLayer(0, "Output");

    Connect(dataIn, fc, inputTensorInfo, 0, 0);

    // Raw slot-to-slot connection: intentionally no TensorInfo supplied for the weights.
    armnn::IOutputSlot& weightsOut = weights->GetOutputSlot(0);
    weightsOut.Connect(fc->GetInputSlot(1));

    Connect(fc, dataOut, outputTensorInfo, 0, 0);

    return net;
}
105*89c4ff92SAndroid Build Coastguard Worker
/// Builds a FullyConnected network that deliberately leaves the weights slot (input slot 1) unconnected,
/// while a ConstantLayer is connected to the bias slot (input slot 2):
///   Input(0)       -> FullyConnected slot 0
///   ConstantLayer  -> FullyConnected slot 2
///   FullyConnected -> Output(0)
/// FullyConnectedErrorChecking uses this to provoke validation errors. When descriptor.m_BiasEnabled is
/// false, the slot-2 connection is expected to throw LayerValidationException.
/// @param inputTensorInfo  TensorInfo given to Connect for Input -> FullyConnected.
/// @param outputTensorInfo TensorInfo given to Connect for FullyConnected -> Output.
/// @param biasTensorInfo   TensorInfo given to Connect for the bias ConstantLayer -> FullyConnected.
/// @param descriptor       Passed unchanged to AddFullyConnectedLayer.
/// @return The constructed network (only if no exception is thrown).
armnn::INetworkPtr CreateFullyConnectedNetworkNoConnectedWeightsExplicit(const armnn::TensorInfo& inputTensorInfo,
                                                                         const armnn::TensorInfo& outputTensorInfo,
                                                                         const armnn::TensorInfo& biasTensorInfo,
                                                                         armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr network(armnn::INetwork::Create());

    // Default-constructed (empty) bias tensor. Presumably only the connection matters for the error
    // paths exercised by the caller, not the bias data itself.
    // Fully qualified so this does not depend on a `using namespace armnn` leaking from another header.
    armnn::ConstTensor biases;

    armnn::IConnectableLayer* inputLayer = network->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* biasLayer = network->AddConstantLayer(biases, "Bias_Input");
    armnn::IConnectableLayer* fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* outputLayer = network->AddOutputLayer(0, "Output");

    Connect(inputLayer, fullyConnectedLayer, inputTensorInfo, 0, 0);
    Connect(biasLayer, fullyConnectedLayer, biasTensorInfo, 0, 2);
    Connect(fullyConnectedLayer, outputLayer, outputTensorInfo, 0, 0);

    return network;
}
127*89c4ff92SAndroid Build Coastguard Worker
/// Builds Input(0) -> FullyConnected -> Output(0) with neither the weights slot (1) nor the bias
/// slot (2) of the FullyConnected layer connected.
/// @param inputTensorInfo  TensorInfo given to Connect for Input -> FullyConnected.
/// @param outputTensorInfo TensorInfo given to Connect for FullyConnected -> Output.
/// @param descriptor       Passed unchanged to AddFullyConnectedLayer.
/// @return The constructed network.
armnn::INetworkPtr CreateFullyConnectedNetworkNoConnectedWeightsAndBias(const armnn::TensorInfo& inputTensorInfo,
                                                                        const armnn::TensorInfo& outputTensorInfo,
                                                                        armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net = armnn::INetwork::Create();

    armnn::IConnectableLayer* const dataIn  = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const fc      = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const dataOut = net->AddOutputLayer(0, "Output");

    Connect(dataIn, fc,      inputTensorInfo,  0, 0);
    Connect(fc,     dataOut, outputTensorInfo, 0, 0);

    return net;
}
143*89c4ff92SAndroid Build Coastguard Worker
/// Builds a FullyConnected network with constant weights on slot 1 but nothing connected to the
/// bias slot (2):
///   Input(0)       -> FullyConnected slot 0
///   ConstantLayer  -> FullyConnected slot 1 (holds weightsConstantTensor)
///   FullyConnected -> Output(0)
/// @param inputTensorInfo       TensorInfo given to Connect for Input -> FullyConnected.
/// @param outputTensorInfo      TensorInfo given to Connect for FullyConnected -> Output.
/// @param weightsTensorInfo     TensorInfo given to Connect for the weights ConstantLayer -> FullyConnected.
/// @param weightsConstantTensor Data for the weights ConstantLayer.
/// @param descriptor            Passed unchanged to AddFullyConnectedLayer.
/// @return The constructed network.
armnn::INetworkPtr CreateFullyConnectedNetworkNoConnectedBiasExplicit(const armnn::TensorInfo& inputTensorInfo,
                                                                      const armnn::TensorInfo& outputTensorInfo,
                                                                      const armnn::TensorInfo& weightsTensorInfo,
                                                                      const armnn::ConstTensor& weightsConstantTensor,
                                                                      armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net = armnn::INetwork::Create();

    armnn::IConnectableLayer* const dataIn  = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const weights = net->AddConstantLayer(weightsConstantTensor, "Weights");
    armnn::IConnectableLayer* const fc      = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const dataOut = net->AddOutputLayer(0, "Output");

    Connect(dataIn,  fc,      inputTensorInfo,   0, 0);
    Connect(weights, fc,      weightsTensorInfo, 0, 1);
    Connect(fc,      dataOut, outputTensorInfo,  0, 0);

    return net;
}
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId> & backends)165*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId>& backends)
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo({ 1, 1, 2, 3 }, ArmnnType);
170*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetQuantizationScale(0.1f);
171*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetQuantizationOffset(63);
172*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo({ 1, 2 }, ArmnnType);
175*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.SetQuantizationScale(5.f);
176*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.SetQuantizationOffset(10);
177*89c4ff92SAndroid Build Coastguard Worker
178*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsTensorInfo({ 2, 6 }, ArmnnType);
179*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetQuantizationScale(0.2f);
180*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetQuantizationOffset(93);
181*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo.SetConstant(true);
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor descriptor;
184*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ConstantWeights = false;
185*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = false;
186*89c4ff92SAndroid Build Coastguard Worker descriptor.m_TransposeWeightMatrix = true;
187*89c4ff92SAndroid Build Coastguard Worker
188*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData {
189*89c4ff92SAndroid Build Coastguard Worker -1.2f, 6.1f, -3.5f,
190*89c4ff92SAndroid Build Coastguard Worker 18.8f, -5.5f, 2.9f
191*89c4ff92SAndroid Build Coastguard Worker };
192*89c4ff92SAndroid Build Coastguard Worker
193*89c4ff92SAndroid Build Coastguard Worker std::vector<T> weightsData {
194*89c4ff92SAndroid Build Coastguard Worker -8.4f, 20.0f, -10.4f, -8, 16.4f, -11.8f,
195*89c4ff92SAndroid Build Coastguard Worker 23.4f, 10.4f, -14.0f, -3.8f, -11.8f, 11.4f
196*89c4ff92SAndroid Build Coastguard Worker };
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker std::vector<T> floatExpectedOutputData {
199*89c4ff92SAndroid Build Coastguard Worker -107.04f, 110.f
200*89c4ff92SAndroid Build Coastguard Worker };
201*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutputData = armnnUtils::QuantizedVector<T>(floatExpectedOutputData);
202*89c4ff92SAndroid Build Coastguard Worker
203*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateFullyConnectedNetworkNonConstWeights(inputTensorInfo,
204*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
205*89c4ff92SAndroid Build Coastguard Worker weightsTensorInfo,
206*89c4ff92SAndroid Build Coastguard Worker descriptor);
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker CHECK(network);
209*89c4ff92SAndroid Build Coastguard Worker
210*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0, inputData }, {1, weightsData}};
211*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputTensorData = {{ 0, expectedOutputData }};
212*89c4ff92SAndroid Build Coastguard Worker
213*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(network),
214*89c4ff92SAndroid Build Coastguard Worker inputTensorData,
215*89c4ff92SAndroid Build Coastguard Worker expectedOutputTensorData,
216*89c4ff92SAndroid Build Coastguard Worker backends,
217*89c4ff92SAndroid Build Coastguard Worker 1.0f);
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::BackendId> & backends,const bool transposeWeights,const bool constantWeightsOrBias)221*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::BackendId>& backends,
222*89c4ff92SAndroid Build Coastguard Worker const bool transposeWeights,
223*89c4ff92SAndroid Build Coastguard Worker const bool constantWeightsOrBias)
224*89c4ff92SAndroid Build Coastguard Worker {
225*89c4ff92SAndroid Build Coastguard Worker unsigned int inputWidth = 1;
226*89c4ff92SAndroid Build Coastguard Worker unsigned int inputHeight = 1;
227*89c4ff92SAndroid Build Coastguard Worker unsigned int inputChannels = 5;
228*89c4ff92SAndroid Build Coastguard Worker unsigned int inputNum = 2;
229*89c4ff92SAndroid Build Coastguard Worker
230*89c4ff92SAndroid Build Coastguard Worker unsigned int outputChannels = 3;
231*89c4ff92SAndroid Build Coastguard Worker unsigned int outputNum = 2;
232*89c4ff92SAndroid Build Coastguard Worker
233*89c4ff92SAndroid Build Coastguard Worker unsigned int inputShape[] = { inputNum, inputChannels, inputHeight, inputWidth };
234*89c4ff92SAndroid Build Coastguard Worker unsigned int outputShape[] = { outputNum, outputChannels };
235*89c4ff92SAndroid Build Coastguard Worker unsigned int weightsShape[] = { inputChannels, outputChannels };
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker if (transposeWeights)
238*89c4ff92SAndroid Build Coastguard Worker {
239*89c4ff92SAndroid Build Coastguard Worker std::swap(weightsShape[0], weightsShape[1]);
240*89c4ff92SAndroid Build Coastguard Worker }
241*89c4ff92SAndroid Build Coastguard Worker
242*89c4ff92SAndroid Build Coastguard Worker unsigned int biasShape[] = { outputChannels };
243*89c4ff92SAndroid Build Coastguard Worker
244*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32, 0.0f, 0, true);
245*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32);
246*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsDesc = armnn::TensorInfo(2, weightsShape, armnn::DataType::Float32, 0.0f, 0, true);
247*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasesDesc = armnn::TensorInfo(1, biasShape, armnn::DataType::Float32, 0.0f, 0, true);
248*89c4ff92SAndroid Build Coastguard Worker
249*89c4ff92SAndroid Build Coastguard Worker std::vector<float> input =
250*89c4ff92SAndroid Build Coastguard Worker {
251*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
252*89c4ff92SAndroid Build Coastguard Worker 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
253*89c4ff92SAndroid Build Coastguard Worker };
254*89c4ff92SAndroid Build Coastguard Worker
255*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weights =
256*89c4ff92SAndroid Build Coastguard Worker {
257*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, .5f,
258*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 1.f,
259*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 2.f,
260*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 3.f,
261*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 4.f
262*89c4ff92SAndroid Build Coastguard Worker };
263*89c4ff92SAndroid Build Coastguard Worker
264*89c4ff92SAndroid Build Coastguard Worker if (transposeWeights)
265*89c4ff92SAndroid Build Coastguard Worker {
266*89c4ff92SAndroid Build Coastguard Worker weights =
267*89c4ff92SAndroid Build Coastguard Worker {
268*89c4ff92SAndroid Build Coastguard Worker .5f, .5f, .5f, .5f, .5f,
269*89c4ff92SAndroid Build Coastguard Worker 2.f, 2.f, 2.f, 2.f, 2.f,
270*89c4ff92SAndroid Build Coastguard Worker .5f, 1.f, 2.f, 3.f, 4.f
271*89c4ff92SAndroid Build Coastguard Worker };
272*89c4ff92SAndroid Build Coastguard Worker }
273*89c4ff92SAndroid Build Coastguard Worker
274*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasValues = std::vector<float>({10.f, 20.f, 30.f});
275*89c4ff92SAndroid Build Coastguard Worker
276*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput =
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker 0.5f + 1.0f + 1.5f + 2.0f + 2.5f + biasValues[0],
279*89c4ff92SAndroid Build Coastguard Worker 2.0f + 4.0f + 6.0f + 8.0f + 10.f + biasValues[1],
280*89c4ff92SAndroid Build Coastguard Worker 0.5f + 2.0f + 6.0f + 12.f + 20.f + biasValues[2],
281*89c4ff92SAndroid Build Coastguard Worker
282*89c4ff92SAndroid Build Coastguard Worker 2.5f + 2.0f + 1.5f + 1.0f + 0.5f + biasValues[0],
283*89c4ff92SAndroid Build Coastguard Worker 10.0f + 8.0f + 6.0f + 4.0f + 2.f + biasValues[1],
284*89c4ff92SAndroid Build Coastguard Worker 2.5f + 4.0f + 6.0f + 6.f + 4.f + biasValues[2]
285*89c4ff92SAndroid Build Coastguard Worker };
286*89c4ff92SAndroid Build Coastguard Worker
287*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor descriptor;
288*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = true;
289*89c4ff92SAndroid Build Coastguard Worker descriptor.m_TransposeWeightMatrix = transposeWeights;
290*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ConstantWeights = constantWeightsOrBias;
291*89c4ff92SAndroid Build Coastguard Worker
292*89c4ff92SAndroid Build Coastguard Worker if (!constantWeightsOrBias)
293*89c4ff92SAndroid Build Coastguard Worker {
294*89c4ff92SAndroid Build Coastguard Worker // Tests non constant weights and constant bias.
295*89c4ff92SAndroid Build Coastguard Worker ConstTensor biasConstantTensor(biasesDesc, biasValues.data());
296*89c4ff92SAndroid Build Coastguard Worker
297*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateFullyConnectedNetworkNonConstWeightsConstBias(inputTensorInfo,
298*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
299*89c4ff92SAndroid Build Coastguard Worker weightsDesc,
300*89c4ff92SAndroid Build Coastguard Worker biasesDesc,
301*89c4ff92SAndroid Build Coastguard Worker biasConstantTensor,
302*89c4ff92SAndroid Build Coastguard Worker descriptor);
303*89c4ff92SAndroid Build Coastguard Worker CHECK(network);
304*89c4ff92SAndroid Build Coastguard Worker
305*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0, input }, {1, weights}};
306*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputTensorData = {{ 0, expectedOutput }};
307*89c4ff92SAndroid Build Coastguard Worker
308*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(network),
309*89c4ff92SAndroid Build Coastguard Worker inputTensorData,
310*89c4ff92SAndroid Build Coastguard Worker expectedOutputTensorData,
311*89c4ff92SAndroid Build Coastguard Worker backends,
312*89c4ff92SAndroid Build Coastguard Worker 1.0f);
313*89c4ff92SAndroid Build Coastguard Worker }
314*89c4ff92SAndroid Build Coastguard Worker else
315*89c4ff92SAndroid Build Coastguard Worker {
316*89c4ff92SAndroid Build Coastguard Worker // Tests constant weights and non constant bias.
317*89c4ff92SAndroid Build Coastguard Worker ConstTensor weightsConstantTensor(weightsDesc, weights.data());
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateFullyConnectedNetworkConstWeightsNonConstBias(inputTensorInfo,
320*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
321*89c4ff92SAndroid Build Coastguard Worker weightsDesc,
322*89c4ff92SAndroid Build Coastguard Worker biasesDesc,
323*89c4ff92SAndroid Build Coastguard Worker weightsConstantTensor,
324*89c4ff92SAndroid Build Coastguard Worker descriptor);
325*89c4ff92SAndroid Build Coastguard Worker CHECK(network);
326*89c4ff92SAndroid Build Coastguard Worker
327*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0, input }, {2, biasValues}};
328*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputTensorData = {{ 0, expectedOutput }};
329*89c4ff92SAndroid Build Coastguard Worker
330*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(network),
331*89c4ff92SAndroid Build Coastguard Worker inputTensorData,
332*89c4ff92SAndroid Build Coastguard Worker expectedOutputTensorData,
333*89c4ff92SAndroid Build Coastguard Worker backends,
334*89c4ff92SAndroid Build Coastguard Worker 1.0f);
335*89c4ff92SAndroid Build Coastguard Worker }
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker
338*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FullyConnectedErrorChecking(const std::vector<armnn::BackendId> & backends,const bool explicitCheck,const bool biasEnabled,const bool connectedWeights,const bool connectedBias,const bool tensorInfoSet)339*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedErrorChecking(const std::vector<armnn::BackendId>& backends,
340*89c4ff92SAndroid Build Coastguard Worker const bool explicitCheck,
341*89c4ff92SAndroid Build Coastguard Worker const bool biasEnabled,
342*89c4ff92SAndroid Build Coastguard Worker const bool connectedWeights,
343*89c4ff92SAndroid Build Coastguard Worker const bool connectedBias,
344*89c4ff92SAndroid Build Coastguard Worker const bool tensorInfoSet)
345*89c4ff92SAndroid Build Coastguard Worker {
346*89c4ff92SAndroid Build Coastguard Worker unsigned int inputWidth = 1;
347*89c4ff92SAndroid Build Coastguard Worker unsigned int inputHeight = 1;
348*89c4ff92SAndroid Build Coastguard Worker unsigned int inputChannels = 5;
349*89c4ff92SAndroid Build Coastguard Worker unsigned int inputNum = 2;
350*89c4ff92SAndroid Build Coastguard Worker
351*89c4ff92SAndroid Build Coastguard Worker unsigned int outputChannels = 3;
352*89c4ff92SAndroid Build Coastguard Worker unsigned int outputNum = 2;
353*89c4ff92SAndroid Build Coastguard Worker
354*89c4ff92SAndroid Build Coastguard Worker unsigned int inputShape[] = { inputNum, inputChannels, inputHeight, inputWidth };
355*89c4ff92SAndroid Build Coastguard Worker unsigned int outputShape[] = { outputNum, outputChannels };
356*89c4ff92SAndroid Build Coastguard Worker unsigned int weightsShape[] = { inputChannels, outputChannels };
357*89c4ff92SAndroid Build Coastguard Worker
358*89c4ff92SAndroid Build Coastguard Worker unsigned int biasShape[] = { outputChannels };
359*89c4ff92SAndroid Build Coastguard Worker
360*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32, 0.0f, 0, true);
361*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32);
362*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsDesc = armnn::TensorInfo(2, weightsShape, armnn::DataType::Float32, 0.0f, 0, true);
363*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasesDesc = armnn::TensorInfo(1, biasShape, armnn::DataType::Float32, 0.0f, 0, true);
364*89c4ff92SAndroid Build Coastguard Worker
365*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weights =
366*89c4ff92SAndroid Build Coastguard Worker {
367*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, .5f,
368*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 1.f,
369*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 2.f,
370*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 3.f,
371*89c4ff92SAndroid Build Coastguard Worker .5f, 2.f, 4.f
372*89c4ff92SAndroid Build Coastguard Worker };
373*89c4ff92SAndroid Build Coastguard Worker
374*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor descriptor;
375*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = biasEnabled;
376*89c4ff92SAndroid Build Coastguard Worker
377*89c4ff92SAndroid Build Coastguard Worker if(explicitCheck)
378*89c4ff92SAndroid Build Coastguard Worker {
379*89c4ff92SAndroid Build Coastguard Worker if(!biasEnabled)
380*89c4ff92SAndroid Build Coastguard Worker {
381*89c4ff92SAndroid Build Coastguard Worker try
382*89c4ff92SAndroid Build Coastguard Worker {
383*89c4ff92SAndroid Build Coastguard Worker CreateFullyConnectedNetworkNoConnectedWeightsExplicit(inputTensorInfo,
384*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
385*89c4ff92SAndroid Build Coastguard Worker biasesDesc,
386*89c4ff92SAndroid Build Coastguard Worker descriptor);
387*89c4ff92SAndroid Build Coastguard Worker FAIL("LayerValidationException should have been thrown");
388*89c4ff92SAndroid Build Coastguard Worker }
389*89c4ff92SAndroid Build Coastguard Worker catch (const LayerValidationException& exc)
390*89c4ff92SAndroid Build Coastguard Worker {
391*89c4ff92SAndroid Build Coastguard Worker CHECK(strcmp(exc.what(), "Tried to connect bias to FullyConnected layer when bias is not enabled: "
392*89c4ff92SAndroid Build Coastguard Worker "Failed to connect to input slot 2 on FullyConnected layer "
393*89c4ff92SAndroid Build Coastguard Worker "\"Fully_Connected\" as the slot does not exist or is unavailable") == 0);
394*89c4ff92SAndroid Build Coastguard Worker }
395*89c4ff92SAndroid Build Coastguard Worker }
396*89c4ff92SAndroid Build Coastguard Worker else if (!connectedWeights)
397*89c4ff92SAndroid Build Coastguard Worker {
398*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateFullyConnectedNetworkNoConnectedWeightsExplicit(inputTensorInfo,
399*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
400*89c4ff92SAndroid Build Coastguard Worker biasesDesc,
401*89c4ff92SAndroid Build Coastguard Worker descriptor);
402*89c4ff92SAndroid Build Coastguard Worker CHECK(network);
403*89c4ff92SAndroid Build Coastguard Worker
404*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run
405*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options;
406*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options));
407*89c4ff92SAndroid Build Coastguard Worker
408*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException);
409*89c4ff92SAndroid Build Coastguard Worker }
410*89c4ff92SAndroid Build Coastguard Worker else if (!connectedBias)
411*89c4ff92SAndroid Build Coastguard Worker {
412*89c4ff92SAndroid Build Coastguard Worker // Tests with constant weights.
413*89c4ff92SAndroid Build Coastguard Worker ConstTensor weightsConstantTensor(weightsDesc, weights.data());
414*89c4ff92SAndroid Build Coastguard Worker
415*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateFullyConnectedNetworkNoConnectedBiasExplicit(inputTensorInfo,
416*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
417*89c4ff92SAndroid Build Coastguard Worker weightsDesc,
418*89c4ff92SAndroid Build Coastguard Worker weightsConstantTensor,
419*89c4ff92SAndroid Build Coastguard Worker descriptor);
420*89c4ff92SAndroid Build Coastguard Worker CHECK(network);
421*89c4ff92SAndroid Build Coastguard Worker
422*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run
423*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options;
424*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options));
425*89c4ff92SAndroid Build Coastguard Worker
426*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException);
427*89c4ff92SAndroid Build Coastguard Worker }
428*89c4ff92SAndroid Build Coastguard Worker }
429*89c4ff92SAndroid Build Coastguard Worker else if(!connectedWeights && !connectedBias)
430*89c4ff92SAndroid Build Coastguard Worker {
431*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateFullyConnectedNetworkNoConnectedWeightsAndBias(inputTensorInfo,
432*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
433*89c4ff92SAndroid Build Coastguard Worker descriptor);
434*89c4ff92SAndroid Build Coastguard Worker CHECK(network);
435*89c4ff92SAndroid Build Coastguard Worker
436*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run
437*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options;
438*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options));
439*89c4ff92SAndroid Build Coastguard Worker
440*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException);
441*89c4ff92SAndroid Build Coastguard Worker }
442*89c4ff92SAndroid Build Coastguard Worker else if(!tensorInfoSet)
443*89c4ff92SAndroid Build Coastguard Worker {
444*89c4ff92SAndroid Build Coastguard Worker // Tests with constant weights.
445*89c4ff92SAndroid Build Coastguard Worker ConstTensor weightsConstantTensor(weightsDesc, weights.data());
446*89c4ff92SAndroid Build Coastguard Worker
447*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateFullyConnectedNetworkNoTensorInfoConstWeights(inputTensorInfo,
448*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
449*89c4ff92SAndroid Build Coastguard Worker weightsConstantTensor,
450*89c4ff92SAndroid Build Coastguard Worker descriptor);
451*89c4ff92SAndroid Build Coastguard Worker CHECK(network);
452*89c4ff92SAndroid Build Coastguard Worker
453*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run
454*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options;
455*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options));
456*89c4ff92SAndroid Build Coastguard Worker
457*89c4ff92SAndroid Build Coastguard Worker try
458*89c4ff92SAndroid Build Coastguard Worker {
459*89c4ff92SAndroid Build Coastguard Worker Optimize(*network, backends, runtime->GetDeviceSpec());
460*89c4ff92SAndroid Build Coastguard Worker FAIL("LayerValidationException should have been thrown");
461*89c4ff92SAndroid Build Coastguard Worker }
462*89c4ff92SAndroid Build Coastguard Worker catch (const LayerValidationException& exc)
463*89c4ff92SAndroid Build Coastguard Worker {
464*89c4ff92SAndroid Build Coastguard Worker CHECK(strcmp(exc.what(), "Output slot TensorInfo not set on Constant layer \"Weights\"") == 0);
465*89c4ff92SAndroid Build Coastguard Worker }
466*89c4ff92SAndroid Build Coastguard Worker }
467*89c4ff92SAndroid Build Coastguard Worker }
468*89c4ff92SAndroid Build Coastguard Worker
469*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
470