xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <vector>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker namespace
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker 
// Builds Input(0) + Input(1) -> FullyConnected -> Output(0), where the weights are fed at runtime
// through network input binding 1 into fully-connected input slot 1. No bias is connected.
//
// inputTensorInfo   - info applied to the data connection (FC slot 0).
// outputTensorInfo  - info applied to the FC -> Output connection.
// weightsTensorInfo - info applied to the weights connection (FC slot 1).
// descriptor        - forwarded unchanged to AddFullyConnectedLayer.
armnn::INetworkPtr CreateFullyConnectedNetworkNonConstWeights(const armnn::TensorInfo& inputTensorInfo,
                                                              const armnn::TensorInfo& outputTensorInfo,
                                                              const armnn::TensorInfo& weightsTensorInfo,
                                                              armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net(armnn::INetwork::Create());

    armnn::IConnectableLayer* const dataIn    = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const weightsIn = net->AddInputLayer(1, "Weights_Input");
    armnn::IConnectableLayer* const fc        = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const result    = net->AddOutputLayer(0, "Output");

    // Connect(src, dst, info, srcSlot, dstSlot) comes from CommonTestUtils.
    Connect(dataIn,    fc,     inputTensorInfo,   0, 0);
    Connect(weightsIn, fc,     weightsTensorInfo, 0, 1);
    Connect(fc,        result, outputTensorInfo,  0, 0);

    return net;
}
40*89c4ff92SAndroid Build Coastguard Worker 
// Builds a fully-connected network whose weights are a runtime input (binding 1 -> FC slot 1)
// and whose bias is a constant layer (FC slot 2).
//
// inputTensorInfo    - info applied to the data connection (FC slot 0).
// outputTensorInfo   - info applied to the FC -> Output connection.
// weightsTensorInfo  - info applied to the weights connection (FC slot 1).
// biasTensorInfo     - info applied to the bias connection (FC slot 2).
// biasConstantTensor - payload of the constant bias layer.
// descriptor         - forwarded unchanged to AddFullyConnectedLayer.
armnn::INetworkPtr CreateFullyConnectedNetworkNonConstWeightsConstBias(const armnn::TensorInfo& inputTensorInfo,
                                                                       const armnn::TensorInfo& outputTensorInfo,
                                                                       const armnn::TensorInfo& weightsTensorInfo,
                                                                       const armnn::TensorInfo& biasTensorInfo,
                                                                       const armnn::ConstTensor& biasConstantTensor,
                                                                       armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr network(armnn::INetwork::Create());

    armnn::IConnectableLayer* inputLayer          = network->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* weightsInputLayer   = network->AddInputLayer(1, "Weights_Input");
    // This constant layer carries the bias, so name it accordingly. It was previously
    // mislabelled "Weights", which is misleading in graph dumps and error messages.
    armnn::IConnectableLayer* biasLayer           = network->AddConstantLayer(biasConstantTensor, "Bias");
    armnn::IConnectableLayer* fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* outputLayer         = network->AddOutputLayer(0, "Output");

    Connect(inputLayer, fullyConnectedLayer, inputTensorInfo, 0, 0);
    Connect(weightsInputLayer, fullyConnectedLayer, weightsTensorInfo, 0, 1);
    Connect(biasLayer, fullyConnectedLayer, biasTensorInfo, 0, 2);
    Connect(fullyConnectedLayer, outputLayer, outputTensorInfo, 0, 0);

    return network;
}
63*89c4ff92SAndroid Build Coastguard Worker 
// Builds a fully-connected network with constant weights (a constant layer on FC slot 1) and a
// bias supplied at runtime through network input binding 2 (FC slot 2).
//
// inputTensorInfo       - info applied to the data connection (FC slot 0).
// outputTensorInfo      - info applied to the FC -> Output connection.
// weightsTensorInfo     - info applied to the weights connection (FC slot 1).
// biasTensorInfo        - info applied to the bias connection (FC slot 2).
// weightsConstantTensor - payload of the constant weights layer.
// descriptor            - forwarded unchanged to AddFullyConnectedLayer.
armnn::INetworkPtr CreateFullyConnectedNetworkConstWeightsNonConstBias(const armnn::TensorInfo& inputTensorInfo,
                                                                       const armnn::TensorInfo& outputTensorInfo,
                                                                       const armnn::TensorInfo& weightsTensorInfo,
                                                                       const armnn::TensorInfo& biasTensorInfo,
                                                                       const armnn::ConstTensor& weightsConstantTensor,
                                                                       armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net(armnn::INetwork::Create());

    armnn::IConnectableLayer* const dataIn   = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const constW   = net->AddConstantLayer(weightsConstantTensor, "Weights");
    armnn::IConnectableLayer* const biasIn   = net->AddInputLayer(2, "Bias_Input");
    armnn::IConnectableLayer* const fc       = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const result   = net->AddOutputLayer(0, "Output");

    Connect(dataIn, fc,     inputTensorInfo,   0, 0);
    Connect(constW, fc,     weightsTensorInfo, 0, 1);
    Connect(biasIn, fc,     biasTensorInfo,    0, 2);
    Connect(fc,     result, outputTensorInfo,  0, 0);

    return net;
}
86*89c4ff92SAndroid Build Coastguard Worker 
// Builds a fully-connected network with constant weights, where the weights connection is made
// with a raw OutputSlot::Connect call instead of the Connect() test helper. Unlike the helper,
// this path passes no TensorInfo, so no info is explicitly set on the weights output slot here.
// The bias slot (2) is left unconnected.
//
// inputTensorInfo       - info applied to the data connection (FC slot 0).
// outputTensorInfo      - info applied to the FC -> Output connection.
// weightsConstantTensor - payload of the constant weights layer.
// descriptor            - forwarded unchanged to AddFullyConnectedLayer.
armnn::INetworkPtr CreateFullyConnectedNetworkNoTensorInfoConstWeights(const armnn::TensorInfo& inputTensorInfo,
                                                                       const armnn::TensorInfo& outputTensorInfo,
                                                                       const armnn::ConstTensor& weightsConstantTensor,
                                                                       armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net(armnn::INetwork::Create());

    armnn::IConnectableLayer* const dataIn = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const constW = net->AddConstantLayer(weightsConstantTensor, "Weights");
    armnn::IConnectableLayer* const fc     = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const result = net->AddOutputLayer(0, "Output");

    Connect(dataIn, fc, inputTensorInfo, 0, 0);

    // Deliberately bypass Connect(): the weights slot gets no TensorInfo from this function.
    armnn::IOutputSlot& weightsOut = constW->GetOutputSlot(0);
    weightsOut.Connect(fc->GetInputSlot(1));

    Connect(fc, result, outputTensorInfo, 0, 0);

    return net;
}
105*89c4ff92SAndroid Build Coastguard Worker 
// Builds a fully-connected network where FC input slot 1 (weights) is intentionally left
// unconnected, while the bias slot (2) is fed from a constant layer. The constant layer holds a
// default-constructed (empty) ConstTensor; the bias connection's TensorInfo is supplied
// separately via biasTensorInfo. Presumably this is used by the error-checking tests below to
// provoke a "weights not connected" failure; confirm against callers.
//
// inputTensorInfo  - info applied to the data connection (FC slot 0).
// outputTensorInfo - info applied to the FC -> Output connection.
// biasTensorInfo   - info applied to the bias connection (FC slot 2).
// descriptor       - forwarded unchanged to AddFullyConnectedLayer.
armnn::INetworkPtr CreateFullyConnectedNetworkNoConnectedWeightsExplicit(const armnn::TensorInfo& inputTensorInfo,
                                                                         const armnn::TensorInfo& outputTensorInfo,
                                                                         const armnn::TensorInfo& biasTensorInfo,
                                                                         armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr network(armnn::INetwork::Create());

    // Qualified explicitly: this header does not itself bring namespace armnn into scope.
    armnn::ConstTensor biases;

    armnn::IConnectableLayer* inputLayer          = network->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* biasLayer           = network->AddConstantLayer(biases, "Bias_Input");
    armnn::IConnectableLayer* fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* outputLayer         = network->AddOutputLayer(0, "Output");

    Connect(inputLayer, fullyConnectedLayer, inputTensorInfo, 0, 0);
    Connect(biasLayer, fullyConnectedLayer, biasTensorInfo, 0, 2);
    Connect(fullyConnectedLayer, outputLayer, outputTensorInfo, 0, 0);

    return network;
}
127*89c4ff92SAndroid Build Coastguard Worker 
// Builds Input -> FullyConnected -> Output with only the data slot (0) wired up.
// Neither the weights slot (1) nor the bias slot (2) receives a connection.
//
// inputTensorInfo  - info applied to the data connection (FC slot 0).
// outputTensorInfo - info applied to the FC -> Output connection.
// descriptor       - forwarded unchanged to AddFullyConnectedLayer.
armnn::INetworkPtr CreateFullyConnectedNetworkNoConnectedWeightsAndBias(const armnn::TensorInfo& inputTensorInfo,
                                                                        const armnn::TensorInfo& outputTensorInfo,
                                                                        armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net(armnn::INetwork::Create());

    armnn::IConnectableLayer* const dataIn = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const fc     = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const result = net->AddOutputLayer(0, "Output");

    Connect(dataIn, fc,     inputTensorInfo,  0, 0);
    Connect(fc,     result, outputTensorInfo, 0, 0);

    return net;
}
143*89c4ff92SAndroid Build Coastguard Worker 
// Builds a fully-connected network with constant weights on FC slot 1 and nothing connected to
// the bias slot (2), regardless of what descriptor.m_BiasEnabled says.
//
// inputTensorInfo       - info applied to the data connection (FC slot 0).
// outputTensorInfo      - info applied to the FC -> Output connection.
// weightsTensorInfo     - info applied to the weights connection (FC slot 1).
// weightsConstantTensor - payload of the constant weights layer.
// descriptor            - forwarded unchanged to AddFullyConnectedLayer.
armnn::INetworkPtr CreateFullyConnectedNetworkNoConnectedBiasExplicit(const armnn::TensorInfo& inputTensorInfo,
                                                                      const armnn::TensorInfo& outputTensorInfo,
                                                                      const armnn::TensorInfo& weightsTensorInfo,
                                                                      const armnn::ConstTensor& weightsConstantTensor,
                                                                      armnn::FullyConnectedDescriptor descriptor)
{
    armnn::INetworkPtr net(armnn::INetwork::Create());

    armnn::IConnectableLayer* const dataIn = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const constW = net->AddConstantLayer(weightsConstantTensor, "Weights");
    armnn::IConnectableLayer* const fc     = net->AddFullyConnectedLayer(descriptor, "Fully_Connected");
    armnn::IConnectableLayer* const result = net->AddOutputLayer(0, "Output");

    Connect(dataIn, fc,     inputTensorInfo,   0, 0);
    Connect(constW, fc,     weightsTensorInfo, 0, 1);
    Connect(fc,     result, outputTensorInfo,  0, 0);

    return net;
}
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId> & backends)165*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId>& backends)
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo({ 1, 1, 2, 3 }, ArmnnType);
170*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetQuantizationScale(0.1f);
171*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetQuantizationOffset(63);
172*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetConstant(true);
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo({ 1, 2 }, ArmnnType);
175*89c4ff92SAndroid Build Coastguard Worker     outputTensorInfo.SetQuantizationScale(5.f);
176*89c4ff92SAndroid Build Coastguard Worker     outputTensorInfo.SetQuantizationOffset(10);
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo weightsTensorInfo({ 2, 6 }, ArmnnType);
179*89c4ff92SAndroid Build Coastguard Worker     weightsTensorInfo.SetQuantizationScale(0.2f);
180*89c4ff92SAndroid Build Coastguard Worker     weightsTensorInfo.SetQuantizationOffset(93);
181*89c4ff92SAndroid Build Coastguard Worker     weightsTensorInfo.SetConstant(true);
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedDescriptor descriptor;
184*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ConstantWeights = false;
185*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled     = false;
186*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_TransposeWeightMatrix = true;
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData {
189*89c4ff92SAndroid Build Coastguard Worker         -1.2f, 6.1f, -3.5f,
190*89c4ff92SAndroid Build Coastguard Worker         18.8f, -5.5f, 2.9f
191*89c4ff92SAndroid Build Coastguard Worker     };
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> weightsData {
194*89c4ff92SAndroid Build Coastguard Worker         -8.4f, 20.0f, -10.4f, -8, 16.4f, -11.8f,
195*89c4ff92SAndroid Build Coastguard Worker         23.4f, 10.4f, -14.0f, -3.8f, -11.8f, 11.4f
196*89c4ff92SAndroid Build Coastguard Worker     };
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> floatExpectedOutputData {
199*89c4ff92SAndroid Build Coastguard Worker         -107.04f, 110.f
200*89c4ff92SAndroid Build Coastguard Worker     };
201*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputData = armnnUtils::QuantizedVector<T>(floatExpectedOutputData);
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = CreateFullyConnectedNetworkNonConstWeights(inputTensorInfo,
204*89c4ff92SAndroid Build Coastguard Worker                                                                             outputTensorInfo,
205*89c4ff92SAndroid Build Coastguard Worker                                                                             weightsTensorInfo,
206*89c4ff92SAndroid Build Coastguard Worker                                                                             descriptor);
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker     CHECK(network);
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData    = {{ 0, inputData }, {1, weightsData}};
211*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputTensorData = {{ 0, expectedOutputData }};
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(network),
214*89c4ff92SAndroid Build Coastguard Worker                                                 inputTensorData,
215*89c4ff92SAndroid Build Coastguard Worker                                                 expectedOutputTensorData,
216*89c4ff92SAndroid Build Coastguard Worker                                                 backends,
217*89c4ff92SAndroid Build Coastguard Worker                                                 1.0f);
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::BackendId> & backends,const bool transposeWeights,const bool constantWeightsOrBias)221*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::BackendId>& backends,
222*89c4ff92SAndroid Build Coastguard Worker                                                        const bool transposeWeights,
223*89c4ff92SAndroid Build Coastguard Worker                                                        const bool constantWeightsOrBias)
224*89c4ff92SAndroid Build Coastguard Worker {
225*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth = 1;
226*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = 1;
227*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputChannels = 5;
228*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputNum = 2;
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputChannels = 3;
231*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputNum = 2;
232*89c4ff92SAndroid Build Coastguard Worker 
233*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputShape[]   = { inputNum, inputChannels, inputHeight, inputWidth };
234*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputShape[]  = { outputNum, outputChannels };
235*89c4ff92SAndroid Build Coastguard Worker     unsigned int weightsShape[] = { inputChannels, outputChannels };
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     if (transposeWeights)
238*89c4ff92SAndroid Build Coastguard Worker     {
239*89c4ff92SAndroid Build Coastguard Worker         std::swap(weightsShape[0], weightsShape[1]);
240*89c4ff92SAndroid Build Coastguard Worker     }
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker     unsigned int biasShape[] = { outputChannels };
243*89c4ff92SAndroid Build Coastguard Worker 
244*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32, 0.0f, 0, true);
245*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32);
246*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo weightsDesc = armnn::TensorInfo(2, weightsShape, armnn::DataType::Float32, 0.0f, 0, true);
247*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo biasesDesc = armnn::TensorInfo(1, biasShape, armnn::DataType::Float32, 0.0f, 0, true);
248*89c4ff92SAndroid Build Coastguard Worker 
249*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> input =
250*89c4ff92SAndroid Build Coastguard Worker     {
251*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
252*89c4ff92SAndroid Build Coastguard Worker         5.0f, 4.0f, 3.0f, 2.0f, 1.0f
253*89c4ff92SAndroid Build Coastguard Worker     };
254*89c4ff92SAndroid Build Coastguard Worker 
255*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weights =
256*89c4ff92SAndroid Build Coastguard Worker     {
257*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, .5f,
258*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 1.f,
259*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 2.f,
260*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 3.f,
261*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 4.f
262*89c4ff92SAndroid Build Coastguard Worker     };
263*89c4ff92SAndroid Build Coastguard Worker 
264*89c4ff92SAndroid Build Coastguard Worker     if (transposeWeights)
265*89c4ff92SAndroid Build Coastguard Worker     {
266*89c4ff92SAndroid Build Coastguard Worker         weights =
267*89c4ff92SAndroid Build Coastguard Worker         {
268*89c4ff92SAndroid Build Coastguard Worker             .5f, .5f, .5f, .5f, .5f,
269*89c4ff92SAndroid Build Coastguard Worker             2.f, 2.f, 2.f, 2.f, 2.f,
270*89c4ff92SAndroid Build Coastguard Worker             .5f, 1.f, 2.f, 3.f, 4.f
271*89c4ff92SAndroid Build Coastguard Worker         };
272*89c4ff92SAndroid Build Coastguard Worker     }
273*89c4ff92SAndroid Build Coastguard Worker 
274*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasValues = std::vector<float>({10.f, 20.f, 30.f});
275*89c4ff92SAndroid Build Coastguard Worker 
276*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput =
277*89c4ff92SAndroid Build Coastguard Worker     {
278*89c4ff92SAndroid Build Coastguard Worker         0.5f + 1.0f + 1.5f + 2.0f + 2.5f + biasValues[0],
279*89c4ff92SAndroid Build Coastguard Worker         2.0f + 4.0f + 6.0f + 8.0f + 10.f + biasValues[1],
280*89c4ff92SAndroid Build Coastguard Worker         0.5f + 2.0f + 6.0f + 12.f + 20.f + biasValues[2],
281*89c4ff92SAndroid Build Coastguard Worker 
282*89c4ff92SAndroid Build Coastguard Worker         2.5f + 2.0f + 1.5f + 1.0f + 0.5f + biasValues[0],
283*89c4ff92SAndroid Build Coastguard Worker         10.0f + 8.0f + 6.0f + 4.0f + 2.f + biasValues[1],
284*89c4ff92SAndroid Build Coastguard Worker         2.5f + 4.0f + 6.0f + 6.f + 4.f   + biasValues[2]
285*89c4ff92SAndroid Build Coastguard Worker     };
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedDescriptor descriptor;
288*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
289*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_TransposeWeightMatrix = transposeWeights;
290*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ConstantWeights = constantWeightsOrBias;
291*89c4ff92SAndroid Build Coastguard Worker 
292*89c4ff92SAndroid Build Coastguard Worker     if (!constantWeightsOrBias)
293*89c4ff92SAndroid Build Coastguard Worker     {
294*89c4ff92SAndroid Build Coastguard Worker         // Tests non constant weights and constant bias.
295*89c4ff92SAndroid Build Coastguard Worker         ConstTensor biasConstantTensor(biasesDesc, biasValues.data());
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network = CreateFullyConnectedNetworkNonConstWeightsConstBias(inputTensorInfo,
298*89c4ff92SAndroid Build Coastguard Worker                                                                                          outputTensorInfo,
299*89c4ff92SAndroid Build Coastguard Worker                                                                                          weightsDesc,
300*89c4ff92SAndroid Build Coastguard Worker                                                                                          biasesDesc,
301*89c4ff92SAndroid Build Coastguard Worker                                                                                          biasConstantTensor,
302*89c4ff92SAndroid Build Coastguard Worker                                                                                          descriptor);
303*89c4ff92SAndroid Build Coastguard Worker         CHECK(network);
304*89c4ff92SAndroid Build Coastguard Worker 
305*89c4ff92SAndroid Build Coastguard Worker         std::map<int, std::vector<T>> inputTensorData    = {{ 0, input }, {1, weights}};
306*89c4ff92SAndroid Build Coastguard Worker         std::map<int, std::vector<T>> expectedOutputTensorData = {{ 0, expectedOutput }};
307*89c4ff92SAndroid Build Coastguard Worker 
308*89c4ff92SAndroid Build Coastguard Worker         EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(network),
309*89c4ff92SAndroid Build Coastguard Worker                                                     inputTensorData,
310*89c4ff92SAndroid Build Coastguard Worker                                                     expectedOutputTensorData,
311*89c4ff92SAndroid Build Coastguard Worker                                                     backends,
312*89c4ff92SAndroid Build Coastguard Worker                                                     1.0f);
313*89c4ff92SAndroid Build Coastguard Worker     }
314*89c4ff92SAndroid Build Coastguard Worker     else
315*89c4ff92SAndroid Build Coastguard Worker     {
316*89c4ff92SAndroid Build Coastguard Worker         // Tests constant weights and non constant bias.
317*89c4ff92SAndroid Build Coastguard Worker         ConstTensor weightsConstantTensor(weightsDesc, weights.data());
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network = CreateFullyConnectedNetworkConstWeightsNonConstBias(inputTensorInfo,
320*89c4ff92SAndroid Build Coastguard Worker                                                                                          outputTensorInfo,
321*89c4ff92SAndroid Build Coastguard Worker                                                                                          weightsDesc,
322*89c4ff92SAndroid Build Coastguard Worker                                                                                          biasesDesc,
323*89c4ff92SAndroid Build Coastguard Worker                                                                                          weightsConstantTensor,
324*89c4ff92SAndroid Build Coastguard Worker                                                                                          descriptor);
325*89c4ff92SAndroid Build Coastguard Worker         CHECK(network);
326*89c4ff92SAndroid Build Coastguard Worker 
327*89c4ff92SAndroid Build Coastguard Worker         std::map<int, std::vector<T>> inputTensorData    = {{ 0, input }, {2, biasValues}};
328*89c4ff92SAndroid Build Coastguard Worker         std::map<int, std::vector<T>> expectedOutputTensorData = {{ 0, expectedOutput }};
329*89c4ff92SAndroid Build Coastguard Worker 
330*89c4ff92SAndroid Build Coastguard Worker         EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(network),
331*89c4ff92SAndroid Build Coastguard Worker                                                     inputTensorData,
332*89c4ff92SAndroid Build Coastguard Worker                                                     expectedOutputTensorData,
333*89c4ff92SAndroid Build Coastguard Worker                                                     backends,
334*89c4ff92SAndroid Build Coastguard Worker                                                     1.0f);
335*89c4ff92SAndroid Build Coastguard Worker     }
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker 
338*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FullyConnectedErrorChecking(const std::vector<armnn::BackendId> & backends,const bool explicitCheck,const bool biasEnabled,const bool connectedWeights,const bool connectedBias,const bool tensorInfoSet)339*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedErrorChecking(const std::vector<armnn::BackendId>& backends,
340*89c4ff92SAndroid Build Coastguard Worker                                  const bool explicitCheck,
341*89c4ff92SAndroid Build Coastguard Worker                                  const bool biasEnabled,
342*89c4ff92SAndroid Build Coastguard Worker                                  const bool connectedWeights,
343*89c4ff92SAndroid Build Coastguard Worker                                  const bool connectedBias,
344*89c4ff92SAndroid Build Coastguard Worker                                  const bool tensorInfoSet)
345*89c4ff92SAndroid Build Coastguard Worker {
346*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth = 1;
347*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = 1;
348*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputChannels = 5;
349*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputNum = 2;
350*89c4ff92SAndroid Build Coastguard Worker 
351*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputChannels = 3;
352*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputNum = 2;
353*89c4ff92SAndroid Build Coastguard Worker 
354*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputShape[]   = { inputNum, inputChannels, inputHeight, inputWidth };
355*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputShape[]  = { outputNum, outputChannels };
356*89c4ff92SAndroid Build Coastguard Worker     unsigned int weightsShape[] = { inputChannels, outputChannels };
357*89c4ff92SAndroid Build Coastguard Worker 
358*89c4ff92SAndroid Build Coastguard Worker     unsigned int biasShape[] = { outputChannels };
359*89c4ff92SAndroid Build Coastguard Worker 
360*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32, 0.0f, 0, true);
361*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32);
362*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo weightsDesc = armnn::TensorInfo(2, weightsShape, armnn::DataType::Float32, 0.0f, 0, true);
363*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo biasesDesc = armnn::TensorInfo(1, biasShape, armnn::DataType::Float32, 0.0f, 0, true);
364*89c4ff92SAndroid Build Coastguard Worker 
365*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weights =
366*89c4ff92SAndroid Build Coastguard Worker     {
367*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, .5f,
368*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 1.f,
369*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 2.f,
370*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 3.f,
371*89c4ff92SAndroid Build Coastguard Worker         .5f, 2.f, 4.f
372*89c4ff92SAndroid Build Coastguard Worker     };
373*89c4ff92SAndroid Build Coastguard Worker 
374*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedDescriptor descriptor;
375*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = biasEnabled;
376*89c4ff92SAndroid Build Coastguard Worker 
377*89c4ff92SAndroid Build Coastguard Worker     if(explicitCheck)
378*89c4ff92SAndroid Build Coastguard Worker     {
379*89c4ff92SAndroid Build Coastguard Worker         if(!biasEnabled)
380*89c4ff92SAndroid Build Coastguard Worker         {
381*89c4ff92SAndroid Build Coastguard Worker             try
382*89c4ff92SAndroid Build Coastguard Worker             {
383*89c4ff92SAndroid Build Coastguard Worker                 CreateFullyConnectedNetworkNoConnectedWeightsExplicit(inputTensorInfo,
384*89c4ff92SAndroid Build Coastguard Worker                                                                       outputTensorInfo,
385*89c4ff92SAndroid Build Coastguard Worker                                                                       biasesDesc,
386*89c4ff92SAndroid Build Coastguard Worker                                                                       descriptor);
387*89c4ff92SAndroid Build Coastguard Worker                 FAIL("LayerValidationException should have been thrown");
388*89c4ff92SAndroid Build Coastguard Worker             }
389*89c4ff92SAndroid Build Coastguard Worker             catch (const LayerValidationException& exc)
390*89c4ff92SAndroid Build Coastguard Worker             {
391*89c4ff92SAndroid Build Coastguard Worker                 CHECK(strcmp(exc.what(), "Tried to connect bias to FullyConnected layer when bias is not enabled: "
392*89c4ff92SAndroid Build Coastguard Worker                                          "Failed to connect to input slot 2 on FullyConnected layer "
393*89c4ff92SAndroid Build Coastguard Worker                                          "\"Fully_Connected\" as the slot does not exist or is unavailable") == 0);
394*89c4ff92SAndroid Build Coastguard Worker             }
395*89c4ff92SAndroid Build Coastguard Worker         }
396*89c4ff92SAndroid Build Coastguard Worker         else if (!connectedWeights)
397*89c4ff92SAndroid Build Coastguard Worker         {
398*89c4ff92SAndroid Build Coastguard Worker             armnn::INetworkPtr network = CreateFullyConnectedNetworkNoConnectedWeightsExplicit(inputTensorInfo,
399*89c4ff92SAndroid Build Coastguard Worker                                                                                                outputTensorInfo,
400*89c4ff92SAndroid Build Coastguard Worker                                                                                                biasesDesc,
401*89c4ff92SAndroid Build Coastguard Worker                                                                                                descriptor);
402*89c4ff92SAndroid Build Coastguard Worker             CHECK(network);
403*89c4ff92SAndroid Build Coastguard Worker 
404*89c4ff92SAndroid Build Coastguard Worker             // Create runtime in which test will run
405*89c4ff92SAndroid Build Coastguard Worker             IRuntime::CreationOptions options;
406*89c4ff92SAndroid Build Coastguard Worker             IRuntimePtr               runtime(IRuntime::Create(options));
407*89c4ff92SAndroid Build Coastguard Worker 
408*89c4ff92SAndroid Build Coastguard Worker             CHECK_THROWS_AS(Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException);
409*89c4ff92SAndroid Build Coastguard Worker         }
410*89c4ff92SAndroid Build Coastguard Worker         else if (!connectedBias)
411*89c4ff92SAndroid Build Coastguard Worker         {
412*89c4ff92SAndroid Build Coastguard Worker             // Tests with constant weights.
413*89c4ff92SAndroid Build Coastguard Worker             ConstTensor weightsConstantTensor(weightsDesc, weights.data());
414*89c4ff92SAndroid Build Coastguard Worker 
415*89c4ff92SAndroid Build Coastguard Worker             armnn::INetworkPtr network = CreateFullyConnectedNetworkNoConnectedBiasExplicit(inputTensorInfo,
416*89c4ff92SAndroid Build Coastguard Worker                                                                                             outputTensorInfo,
417*89c4ff92SAndroid Build Coastguard Worker                                                                                             weightsDesc,
418*89c4ff92SAndroid Build Coastguard Worker                                                                                             weightsConstantTensor,
419*89c4ff92SAndroid Build Coastguard Worker                                                                                             descriptor);
420*89c4ff92SAndroid Build Coastguard Worker             CHECK(network);
421*89c4ff92SAndroid Build Coastguard Worker 
422*89c4ff92SAndroid Build Coastguard Worker             // Create runtime in which test will run
423*89c4ff92SAndroid Build Coastguard Worker             IRuntime::CreationOptions options;
424*89c4ff92SAndroid Build Coastguard Worker             IRuntimePtr               runtime(IRuntime::Create(options));
425*89c4ff92SAndroid Build Coastguard Worker 
426*89c4ff92SAndroid Build Coastguard Worker             CHECK_THROWS_AS(Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException);
427*89c4ff92SAndroid Build Coastguard Worker         }
428*89c4ff92SAndroid Build Coastguard Worker     }
429*89c4ff92SAndroid Build Coastguard Worker     else if(!connectedWeights && !connectedBias)
430*89c4ff92SAndroid Build Coastguard Worker     {
431*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network = CreateFullyConnectedNetworkNoConnectedWeightsAndBias(inputTensorInfo,
432*89c4ff92SAndroid Build Coastguard Worker                                                                                           outputTensorInfo,
433*89c4ff92SAndroid Build Coastguard Worker                                                                                           descriptor);
434*89c4ff92SAndroid Build Coastguard Worker         CHECK(network);
435*89c4ff92SAndroid Build Coastguard Worker 
436*89c4ff92SAndroid Build Coastguard Worker         // Create runtime in which test will run
437*89c4ff92SAndroid Build Coastguard Worker         IRuntime::CreationOptions options;
438*89c4ff92SAndroid Build Coastguard Worker         IRuntimePtr               runtime(IRuntime::Create(options));
439*89c4ff92SAndroid Build Coastguard Worker 
440*89c4ff92SAndroid Build Coastguard Worker         CHECK_THROWS_AS(Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException);
441*89c4ff92SAndroid Build Coastguard Worker     }
442*89c4ff92SAndroid Build Coastguard Worker     else if(!tensorInfoSet)
443*89c4ff92SAndroid Build Coastguard Worker     {
444*89c4ff92SAndroid Build Coastguard Worker         // Tests with constant weights.
445*89c4ff92SAndroid Build Coastguard Worker         ConstTensor weightsConstantTensor(weightsDesc, weights.data());
446*89c4ff92SAndroid Build Coastguard Worker 
447*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network = CreateFullyConnectedNetworkNoTensorInfoConstWeights(inputTensorInfo,
448*89c4ff92SAndroid Build Coastguard Worker                                                                                          outputTensorInfo,
449*89c4ff92SAndroid Build Coastguard Worker                                                                                          weightsConstantTensor,
450*89c4ff92SAndroid Build Coastguard Worker                                                                                          descriptor);
451*89c4ff92SAndroid Build Coastguard Worker         CHECK(network);
452*89c4ff92SAndroid Build Coastguard Worker 
453*89c4ff92SAndroid Build Coastguard Worker         // Create runtime in which test will run
454*89c4ff92SAndroid Build Coastguard Worker         IRuntime::CreationOptions options;
455*89c4ff92SAndroid Build Coastguard Worker         IRuntimePtr runtime(IRuntime::Create(options));
456*89c4ff92SAndroid Build Coastguard Worker 
457*89c4ff92SAndroid Build Coastguard Worker         try
458*89c4ff92SAndroid Build Coastguard Worker         {
459*89c4ff92SAndroid Build Coastguard Worker             Optimize(*network, backends, runtime->GetDeviceSpec());
460*89c4ff92SAndroid Build Coastguard Worker             FAIL("LayerValidationException should have been thrown");
461*89c4ff92SAndroid Build Coastguard Worker         }
462*89c4ff92SAndroid Build Coastguard Worker         catch (const LayerValidationException& exc)
463*89c4ff92SAndroid Build Coastguard Worker         {
464*89c4ff92SAndroid Build Coastguard Worker             CHECK(strcmp(exc.what(), "Output slot TensorInfo not set on Constant layer \"Weights\"") == 0);
465*89c4ff92SAndroid Build Coastguard Worker         }
466*89c4ff92SAndroid Build Coastguard Worker     }
467*89c4ff92SAndroid Build Coastguard Worker }
468*89c4ff92SAndroid Build Coastguard Worker 
469*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
470