// Source: external/armnn/src/armnnOnnxParser/test/FullyConnected.cpp
//         (AOSP aosp_15_r20, revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include  "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_FullyConnected")
10*89c4ff92SAndroid Build Coastguard Worker {
// A MatMul in isolation, not connected to an Add. Should result in a non-biased FullyConnected layer.
//
// Graph: Output[1x1] = MatMul(Input[1x1], Const[1x1]), where "Const" is an initializer holding 17.0.
// "Const" is declared both as a graph input and as an initializer in the prototext below.
struct MatMulFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
    MatMulFixture()
    {
        // ONNX model in protobuf text format (ir_version 3, opset 7) handed to the base fixture.
        m_Prototext = R"(
                    ir_version: 3
                    producer_name:  "CNTK "
                    producer_version:  "2.5.1 "
                    domain:  "ai.cntk "
                    model_version: 1
                    graph {
                      name:  "CNTKGraph "
                      input {
                         name: "Input"
                         type {
                           tensor_type {
                             elem_type: 1
                             shape {
                               dim {
                                 dim_value: 1
                               }
                               dim {
                                 dim_value: 1
                               }
                             }
                           }
                         }
                       }
                       input {
                          name: "Const"
                          type {
                            tensor_type {
                              elem_type: 1
                              shape {
                                dim {
                                  dim_value: 1
                                }
                                dim {
                                  dim_value: 1
                                }
                              }
                            }
                          }
                        }
                        initializer {
                          dims: 1
                          dims: 1
                          data_type: 1
                          float_data: 17.0
                          name: "Const"
                       }
                       node {
                           input: "Input"
                           input: "Const"
                           output: "Output"
                           name: "SimpleMatmul"
                           op_type: "MatMul"
                       }
                      output {
                           name:  "Output"
                           type {
                              tensor_type {
                                elem_type: 1
                                shape {
                                  dim {
                                     dim_value: 1
                                  }
                                  dim {
                                     dim_value: 1
                                  }
                                }
                              }
                           }
                       }
                    }
                    opset_import {
                       version: 7
                     })";

        // Setup() is inherited from ParserPrototxtFixture; presumably it parses m_Prototext and
        // prepares the network used by RunTest — confirm in ParserPrototxtFixture.hpp.
        Setup();
    }
};

95*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MatMulFixture, "MatMul")
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input", { 2 }}}, {{"Output", { 34 }}});
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker 
// In ONNX, a fully connected layer is expressed as a MatMul followed by an Add.
// The OnnxParser must detect this pattern and convert the pair into a single FullyConnected layer.
//
// Graph: AddInput[1x1] = MatMul(Input[1x1], Weight[1x1]), Weight initializer = 2
//        Output[1x1]   = Add(AddInput, Bias[1]),          Bias initializer   = 1
// The intermediate tensor "AddInput" has its shape declared through value_info.
struct FullyConnectedFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
    FullyConnectedFixture()
    {
        // ONNX model in protobuf text format (ir_version 3, opset 7) handed to the base fixture.
        m_Prototext = R"(
                    ir_version: 3
                    producer_name:  "CNTK "
                    producer_version:  "2.5.1 "
                    domain:  "ai.cntk "
                    model_version: 1
                    graph {
                      name:  "CNTKGraph "
                      input {
                         name: "Input"
                         type {
                           tensor_type {
                             elem_type: 1
                             shape {
                               dim {
                                 dim_value: 1
                               }
                               dim {
                                 dim_value: 1
                               }
                             }
                           }
                         }
                       }
                       input {
                          name: "Weight"
                          type {
                            tensor_type {
                              elem_type: 1
                              shape {
                                dim {
                                  dim_value: 1
                                }
                                dim {
                                  dim_value: 1
                                }
                              }
                            }
                          }
                        }
                        initializer {
                          dims: 1
                          dims: 1
                          data_type: 1
                          float_data: 2
                          name: "Weight"
                       }
                       input {
                          name: "Bias"
                          type {
                            tensor_type {
                              elem_type: 1
                              shape {
                                dim {
                                  dim_value: 1
                                }
                              }
                            }
                          }
                        }
                        initializer {
                          dims: 1
                          data_type: 1
                          float_data: 1
                          name: "Bias"
                       }
                       node {
                           input: "Input"
                           input: "Weight"
                           output: "AddInput"
                           name: "FCMatmul"
                           op_type: "MatMul"
                       }
                       node {
                           input: "AddInput"
                           input: "Bias"
                           output: "Output"
                           name: "FCAdd"
                           op_type: "Add"
                       }
                       value_info {
                            name: "AddInput"
                            type {
                              tensor_type {
                                elem_type: 1
                                shape {
                                  dim {
                                    dim_value: 1
                                  }
                                  dim {
                                    dim_value: 1
                                  }
                                }
                              }
                            }
                          }
                      output {
                           name:  "Output"
                           type {
                              tensor_type {
                                elem_type: 1
                                shape {
                                  dim {
                                     dim_value: 1
                                  }
                                  dim {
                                     dim_value: 1
                                  }
                                }
                              }
                           }
                       }
                    }
                    opset_import {
                       version: 7
                     })";

        // Setup() is inherited from ParserPrototxtFixture; presumably it parses m_Prototext and
        // prepares the network used by RunTest — confirm in ParserPrototxtFixture.hpp.
        Setup();
    }
};

227*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedFixture, "FullyConnected")
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input", { 3 }}}, {{"Output", { 7 }}});
230*89c4ff92SAndroid Build Coastguard Worker }
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker 
// Similar to FullyConnectedFixture, but this time the MatMul's output is consumed by two Adds. This should
// result in two FullyConnected layers being created, whose outputs are then combined by a final Add.
// Legend: I = Input, M = MatMul ("FCMatmul"), A = Add, C = constant initializer (Weight, Bias, Bias_1).
//      I
//      |
//      M -- C
//     / \'
// C-- A  A -- C
//     \ /
//      A
242*89c4ff92SAndroid Build Coastguard Worker struct MatMulUsedInTwoFcFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
243*89c4ff92SAndroid Build Coastguard Worker {
MatMulUsedInTwoFcFixtureMatMulUsedInTwoFcFixture244*89c4ff92SAndroid Build Coastguard Worker     MatMulUsedInTwoFcFixture()
245*89c4ff92SAndroid Build Coastguard Worker     {
246*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
247*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 3
248*89c4ff92SAndroid Build Coastguard Worker                     producer_name:  "CNTK "
249*89c4ff92SAndroid Build Coastguard Worker                     producer_version:  "2.5.1 "
250*89c4ff92SAndroid Build Coastguard Worker                     domain:  "ai.cntk "
251*89c4ff92SAndroid Build Coastguard Worker                     model_version: 1
252*89c4ff92SAndroid Build Coastguard Worker                     graph {
253*89c4ff92SAndroid Build Coastguard Worker                       name:  "CNTKGraph "
254*89c4ff92SAndroid Build Coastguard Worker                       input {
255*89c4ff92SAndroid Build Coastguard Worker                          name: "Input"
256*89c4ff92SAndroid Build Coastguard Worker                          type {
257*89c4ff92SAndroid Build Coastguard Worker                            tensor_type {
258*89c4ff92SAndroid Build Coastguard Worker                              elem_type: 1
259*89c4ff92SAndroid Build Coastguard Worker                              shape {
260*89c4ff92SAndroid Build Coastguard Worker                                dim {
261*89c4ff92SAndroid Build Coastguard Worker                                  dim_value: 1
262*89c4ff92SAndroid Build Coastguard Worker                                }
263*89c4ff92SAndroid Build Coastguard Worker                                dim {
264*89c4ff92SAndroid Build Coastguard Worker                                  dim_value: 1
265*89c4ff92SAndroid Build Coastguard Worker                                }
266*89c4ff92SAndroid Build Coastguard Worker                              }
267*89c4ff92SAndroid Build Coastguard Worker                            }
268*89c4ff92SAndroid Build Coastguard Worker                          }
269*89c4ff92SAndroid Build Coastguard Worker                        }
270*89c4ff92SAndroid Build Coastguard Worker                        input {
271*89c4ff92SAndroid Build Coastguard Worker                           name: "Weight"
272*89c4ff92SAndroid Build Coastguard Worker                           type {
273*89c4ff92SAndroid Build Coastguard Worker                             tensor_type {
274*89c4ff92SAndroid Build Coastguard Worker                               elem_type: 1
275*89c4ff92SAndroid Build Coastguard Worker                               shape {
276*89c4ff92SAndroid Build Coastguard Worker                                 dim {
277*89c4ff92SAndroid Build Coastguard Worker                                   dim_value: 1
278*89c4ff92SAndroid Build Coastguard Worker                                 }
279*89c4ff92SAndroid Build Coastguard Worker                                 dim {
280*89c4ff92SAndroid Build Coastguard Worker                                   dim_value: 1
281*89c4ff92SAndroid Build Coastguard Worker                                 }
282*89c4ff92SAndroid Build Coastguard Worker                               }
283*89c4ff92SAndroid Build Coastguard Worker                             }
284*89c4ff92SAndroid Build Coastguard Worker                           }
285*89c4ff92SAndroid Build Coastguard Worker                         }
286*89c4ff92SAndroid Build Coastguard Worker                         initializer {
287*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
288*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
289*89c4ff92SAndroid Build Coastguard Worker                           data_type: 1
290*89c4ff92SAndroid Build Coastguard Worker                           float_data: 2
291*89c4ff92SAndroid Build Coastguard Worker                           name: "Weight"
292*89c4ff92SAndroid Build Coastguard Worker                        }
293*89c4ff92SAndroid Build Coastguard Worker                        input {
294*89c4ff92SAndroid Build Coastguard Worker                           name: "Bias"
295*89c4ff92SAndroid Build Coastguard Worker                           type {
296*89c4ff92SAndroid Build Coastguard Worker                             tensor_type {
297*89c4ff92SAndroid Build Coastguard Worker                               elem_type: 1
298*89c4ff92SAndroid Build Coastguard Worker                               shape {
299*89c4ff92SAndroid Build Coastguard Worker                                 dim {
300*89c4ff92SAndroid Build Coastguard Worker                                   dim_value: 1
301*89c4ff92SAndroid Build Coastguard Worker                                 }
302*89c4ff92SAndroid Build Coastguard Worker                               }
303*89c4ff92SAndroid Build Coastguard Worker                             }
304*89c4ff92SAndroid Build Coastguard Worker                           }
305*89c4ff92SAndroid Build Coastguard Worker                         }
306*89c4ff92SAndroid Build Coastguard Worker                         initializer {
307*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
308*89c4ff92SAndroid Build Coastguard Worker                           data_type: 1
309*89c4ff92SAndroid Build Coastguard Worker                           float_data: 1
310*89c4ff92SAndroid Build Coastguard Worker                           name: "Bias"
311*89c4ff92SAndroid Build Coastguard Worker                        }
312*89c4ff92SAndroid Build Coastguard Worker                        input {
313*89c4ff92SAndroid Build Coastguard Worker                           name: "Bias_1"
314*89c4ff92SAndroid Build Coastguard Worker                           type {
315*89c4ff92SAndroid Build Coastguard Worker                             tensor_type {
316*89c4ff92SAndroid Build Coastguard Worker                               elem_type: 1
317*89c4ff92SAndroid Build Coastguard Worker                               shape {
318*89c4ff92SAndroid Build Coastguard Worker                                 dim {
319*89c4ff92SAndroid Build Coastguard Worker                                   dim_value: 1
320*89c4ff92SAndroid Build Coastguard Worker                                 }
321*89c4ff92SAndroid Build Coastguard Worker                               }
322*89c4ff92SAndroid Build Coastguard Worker                             }
323*89c4ff92SAndroid Build Coastguard Worker                           }
324*89c4ff92SAndroid Build Coastguard Worker                         }
325*89c4ff92SAndroid Build Coastguard Worker                         initializer {
326*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
327*89c4ff92SAndroid Build Coastguard Worker                           data_type: 1
328*89c4ff92SAndroid Build Coastguard Worker                           float_data: 10.0
329*89c4ff92SAndroid Build Coastguard Worker                           name: "Bias_1"
330*89c4ff92SAndroid Build Coastguard Worker                        }
331*89c4ff92SAndroid Build Coastguard Worker                        node {
332*89c4ff92SAndroid Build Coastguard Worker                            input: "Input"
333*89c4ff92SAndroid Build Coastguard Worker                            input: "Weight"
334*89c4ff92SAndroid Build Coastguard Worker                            output: "AddInput"
335*89c4ff92SAndroid Build Coastguard Worker                            name: "FCMatmul"
336*89c4ff92SAndroid Build Coastguard Worker                            op_type: "MatMul"
337*89c4ff92SAndroid Build Coastguard Worker                        }
338*89c4ff92SAndroid Build Coastguard Worker                        node {
339*89c4ff92SAndroid Build Coastguard Worker                            input: "AddInput"
340*89c4ff92SAndroid Build Coastguard Worker                            input: "Bias"
341*89c4ff92SAndroid Build Coastguard Worker                            output: "AddOutput"
342*89c4ff92SAndroid Build Coastguard Worker                            name: "FCAdd"
343*89c4ff92SAndroid Build Coastguard Worker                            op_type: "Add"
344*89c4ff92SAndroid Build Coastguard Worker                        }
345*89c4ff92SAndroid Build Coastguard Worker                        node {
346*89c4ff92SAndroid Build Coastguard Worker                            input: "AddInput"
347*89c4ff92SAndroid Build Coastguard Worker                            input: "Bias_1"
348*89c4ff92SAndroid Build Coastguard Worker                            output: "AddOutput_1"
349*89c4ff92SAndroid Build Coastguard Worker                            name: "FCAdd_1"
350*89c4ff92SAndroid Build Coastguard Worker                            op_type: "Add"
351*89c4ff92SAndroid Build Coastguard Worker                        }
352*89c4ff92SAndroid Build Coastguard Worker                        node {
353*89c4ff92SAndroid Build Coastguard Worker                            input: "AddOutput"
354*89c4ff92SAndroid Build Coastguard Worker                            input: "AddOutput_1"
355*89c4ff92SAndroid Build Coastguard Worker                            output: "Output"
356*89c4ff92SAndroid Build Coastguard Worker                            name: "FinalAdd"
357*89c4ff92SAndroid Build Coastguard Worker                            op_type: "Add"
358*89c4ff92SAndroid Build Coastguard Worker                        }
359*89c4ff92SAndroid Build Coastguard Worker                        value_info {
360*89c4ff92SAndroid Build Coastguard Worker                             name: "AddInput"
361*89c4ff92SAndroid Build Coastguard Worker                             type {
362*89c4ff92SAndroid Build Coastguard Worker                               tensor_type {
363*89c4ff92SAndroid Build Coastguard Worker                                 elem_type: 1
364*89c4ff92SAndroid Build Coastguard Worker                                 shape {
365*89c4ff92SAndroid Build Coastguard Worker                                   dim {
366*89c4ff92SAndroid Build Coastguard Worker                                     dim_value: 1
367*89c4ff92SAndroid Build Coastguard Worker                                   }
368*89c4ff92SAndroid Build Coastguard Worker                                   dim {
369*89c4ff92SAndroid Build Coastguard Worker                                     dim_value: 1
370*89c4ff92SAndroid Build Coastguard Worker                                   }
371*89c4ff92SAndroid Build Coastguard Worker                                 }
372*89c4ff92SAndroid Build Coastguard Worker                               }
373*89c4ff92SAndroid Build Coastguard Worker                             }
374*89c4ff92SAndroid Build Coastguard Worker                           }
375*89c4ff92SAndroid Build Coastguard Worker                       value_info {
376*89c4ff92SAndroid Build Coastguard Worker                            name:  "AddOutput"
377*89c4ff92SAndroid Build Coastguard Worker                            type {
378*89c4ff92SAndroid Build Coastguard Worker                               tensor_type {
379*89c4ff92SAndroid Build Coastguard Worker                                 elem_type: 1
380*89c4ff92SAndroid Build Coastguard Worker                                 shape {
381*89c4ff92SAndroid Build Coastguard Worker                                   dim {
382*89c4ff92SAndroid Build Coastguard Worker                                      dim_value: 1
383*89c4ff92SAndroid Build Coastguard Worker                                   }
384*89c4ff92SAndroid Build Coastguard Worker                                   dim {
385*89c4ff92SAndroid Build Coastguard Worker                                      dim_value: 1
386*89c4ff92SAndroid Build Coastguard Worker                                   }
387*89c4ff92SAndroid Build Coastguard Worker                                 }
388*89c4ff92SAndroid Build Coastguard Worker                               }
389*89c4ff92SAndroid Build Coastguard Worker                            }
390*89c4ff92SAndroid Build Coastguard Worker                        }
391*89c4ff92SAndroid Build Coastguard Worker                        value_info {
392*89c4ff92SAndroid Build Coastguard Worker                             name:  "AddOutput_1"
393*89c4ff92SAndroid Build Coastguard Worker                             type {
394*89c4ff92SAndroid Build Coastguard Worker                                tensor_type {
395*89c4ff92SAndroid Build Coastguard Worker                                  elem_type: 1
396*89c4ff92SAndroid Build Coastguard Worker                                  shape {
397*89c4ff92SAndroid Build Coastguard Worker                                    dim {
398*89c4ff92SAndroid Build Coastguard Worker                                       dim_value: 1
399*89c4ff92SAndroid Build Coastguard Worker                                    }
400*89c4ff92SAndroid Build Coastguard Worker                                    dim {
401*89c4ff92SAndroid Build Coastguard Worker                                       dim_value: 1
402*89c4ff92SAndroid Build Coastguard Worker                                    }
403*89c4ff92SAndroid Build Coastguard Worker                                  }
404*89c4ff92SAndroid Build Coastguard Worker                                }
405*89c4ff92SAndroid Build Coastguard Worker                             }
406*89c4ff92SAndroid Build Coastguard Worker                         }
407*89c4ff92SAndroid Build Coastguard Worker                         output {
408*89c4ff92SAndroid Build Coastguard Worker                              name:  "Output"
409*89c4ff92SAndroid Build Coastguard Worker                              type {
410*89c4ff92SAndroid Build Coastguard Worker                                 tensor_type {
411*89c4ff92SAndroid Build Coastguard Worker                                   elem_type: 1
412*89c4ff92SAndroid Build Coastguard Worker                                   shape {
413*89c4ff92SAndroid Build Coastguard Worker                                     dim {
414*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
415*89c4ff92SAndroid Build Coastguard Worker                                     }
416*89c4ff92SAndroid Build Coastguard Worker                                     dim {
417*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
418*89c4ff92SAndroid Build Coastguard Worker                                     }
419*89c4ff92SAndroid Build Coastguard Worker                                   }
420*89c4ff92SAndroid Build Coastguard Worker                                 }
421*89c4ff92SAndroid Build Coastguard Worker                              }
422*89c4ff92SAndroid Build Coastguard Worker                          }
423*89c4ff92SAndroid Build Coastguard Worker                     }
424*89c4ff92SAndroid Build Coastguard Worker                     opset_import {
425*89c4ff92SAndroid Build Coastguard Worker                        version: 7
426*89c4ff92SAndroid Build Coastguard Worker                      })";
427*89c4ff92SAndroid Build Coastguard Worker 
428*89c4ff92SAndroid Build Coastguard Worker         Setup();
429*89c4ff92SAndroid Build Coastguard Worker     }
430*89c4ff92SAndroid Build Coastguard Worker };
431*89c4ff92SAndroid Build Coastguard Worker 
432*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MatMulUsedInTwoFcFixture, "MatMulUsedInTwoFc")
433*89c4ff92SAndroid Build Coastguard Worker {
434*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input", { 3 }}}, {{"Output", { 23 }}});
435*89c4ff92SAndroid Build Coastguard Worker }
436*89c4ff92SAndroid Build Coastguard Worker 
437*89c4ff92SAndroid Build Coastguard Worker 
// Similar to MatMulUsedInTwoFc, but this time the Adds are 'staggered' (see diagram): the MatMul output feeds both
// Adds, and the final Add also consumes the output of the first Add. This means that only one FullyConnected layer
// can be created (the other should just remain an Add).
// Legend: I = Input, M = MatMul, C1 = Weight constant, C2 = Bias constant, A = Add.
// (The trailing ' on the "/ \" line stops the backslash from acting as a line continuation.)
//        I
//        |
//        M -- C1
//       / \'
// C2 -- A  |
//       \ /
//        A
447*89c4ff92SAndroid Build Coastguard Worker struct MatMulUsedInTwoFcStaggeredFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
448*89c4ff92SAndroid Build Coastguard Worker {
MatMulUsedInTwoFcStaggeredFixtureMatMulUsedInTwoFcStaggeredFixture449*89c4ff92SAndroid Build Coastguard Worker     MatMulUsedInTwoFcStaggeredFixture()
450*89c4ff92SAndroid Build Coastguard Worker     {
451*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
452*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 3
453*89c4ff92SAndroid Build Coastguard Worker                     producer_name:  "CNTK "
454*89c4ff92SAndroid Build Coastguard Worker                     producer_version:  "2.5.1 "
455*89c4ff92SAndroid Build Coastguard Worker                     domain:  "ai.cntk "
456*89c4ff92SAndroid Build Coastguard Worker                     model_version: 1
457*89c4ff92SAndroid Build Coastguard Worker                     graph {
458*89c4ff92SAndroid Build Coastguard Worker                       name:  "CNTKGraph "
459*89c4ff92SAndroid Build Coastguard Worker                       input {
460*89c4ff92SAndroid Build Coastguard Worker                          name: "Input"
461*89c4ff92SAndroid Build Coastguard Worker                          type {
462*89c4ff92SAndroid Build Coastguard Worker                            tensor_type {
463*89c4ff92SAndroid Build Coastguard Worker                              elem_type: 1
464*89c4ff92SAndroid Build Coastguard Worker                              shape {
465*89c4ff92SAndroid Build Coastguard Worker                                dim {
466*89c4ff92SAndroid Build Coastguard Worker                                  dim_value: 1
467*89c4ff92SAndroid Build Coastguard Worker                                }
468*89c4ff92SAndroid Build Coastguard Worker                                dim {
469*89c4ff92SAndroid Build Coastguard Worker                                  dim_value: 1
470*89c4ff92SAndroid Build Coastguard Worker                                }
471*89c4ff92SAndroid Build Coastguard Worker                              }
472*89c4ff92SAndroid Build Coastguard Worker                            }
473*89c4ff92SAndroid Build Coastguard Worker                          }
474*89c4ff92SAndroid Build Coastguard Worker                        }
475*89c4ff92SAndroid Build Coastguard Worker                        input {
476*89c4ff92SAndroid Build Coastguard Worker                           name: "Weight"
477*89c4ff92SAndroid Build Coastguard Worker                           type {
478*89c4ff92SAndroid Build Coastguard Worker                             tensor_type {
479*89c4ff92SAndroid Build Coastguard Worker                               elem_type: 1
480*89c4ff92SAndroid Build Coastguard Worker                               shape {
481*89c4ff92SAndroid Build Coastguard Worker                                 dim {
482*89c4ff92SAndroid Build Coastguard Worker                                   dim_value: 1
483*89c4ff92SAndroid Build Coastguard Worker                                 }
484*89c4ff92SAndroid Build Coastguard Worker                                 dim {
485*89c4ff92SAndroid Build Coastguard Worker                                   dim_value: 1
486*89c4ff92SAndroid Build Coastguard Worker                                 }
487*89c4ff92SAndroid Build Coastguard Worker                               }
488*89c4ff92SAndroid Build Coastguard Worker                             }
489*89c4ff92SAndroid Build Coastguard Worker                           }
490*89c4ff92SAndroid Build Coastguard Worker                         }
491*89c4ff92SAndroid Build Coastguard Worker                         initializer {
492*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
493*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
494*89c4ff92SAndroid Build Coastguard Worker                           data_type: 1
495*89c4ff92SAndroid Build Coastguard Worker                           float_data: 2
496*89c4ff92SAndroid Build Coastguard Worker                           name: "Weight"
497*89c4ff92SAndroid Build Coastguard Worker                        }
498*89c4ff92SAndroid Build Coastguard Worker                        input {
499*89c4ff92SAndroid Build Coastguard Worker                           name: "Bias"
500*89c4ff92SAndroid Build Coastguard Worker                           type {
501*89c4ff92SAndroid Build Coastguard Worker                             tensor_type {
502*89c4ff92SAndroid Build Coastguard Worker                               elem_type: 1
503*89c4ff92SAndroid Build Coastguard Worker                               shape {
504*89c4ff92SAndroid Build Coastguard Worker                                 dim {
505*89c4ff92SAndroid Build Coastguard Worker                                   dim_value: 1
506*89c4ff92SAndroid Build Coastguard Worker                                 }
507*89c4ff92SAndroid Build Coastguard Worker                               }
508*89c4ff92SAndroid Build Coastguard Worker                             }
509*89c4ff92SAndroid Build Coastguard Worker                           }
510*89c4ff92SAndroid Build Coastguard Worker                         }
511*89c4ff92SAndroid Build Coastguard Worker                         initializer {
512*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
513*89c4ff92SAndroid Build Coastguard Worker                           data_type: 1
514*89c4ff92SAndroid Build Coastguard Worker                           float_data: 1
515*89c4ff92SAndroid Build Coastguard Worker                           name: "Bias"
516*89c4ff92SAndroid Build Coastguard Worker                        }
517*89c4ff92SAndroid Build Coastguard Worker                         node {
518*89c4ff92SAndroid Build Coastguard Worker                            input: "Input"
519*89c4ff92SAndroid Build Coastguard Worker                            input: "Weight"
520*89c4ff92SAndroid Build Coastguard Worker                            output: "AddInput"
521*89c4ff92SAndroid Build Coastguard Worker                            name: "MatmulFC&NFC"
522*89c4ff92SAndroid Build Coastguard Worker                            op_type: "MatMul"
523*89c4ff92SAndroid Build Coastguard Worker                        }
524*89c4ff92SAndroid Build Coastguard Worker                        node {
525*89c4ff92SAndroid Build Coastguard Worker                            input: "AddInput"
526*89c4ff92SAndroid Build Coastguard Worker                            input: "Bias"
527*89c4ff92SAndroid Build Coastguard Worker                            output: "AddOutput"
528*89c4ff92SAndroid Build Coastguard Worker                            name: "FCAdd"
529*89c4ff92SAndroid Build Coastguard Worker                            op_type: "Add"
530*89c4ff92SAndroid Build Coastguard Worker                        }
531*89c4ff92SAndroid Build Coastguard Worker 
532*89c4ff92SAndroid Build Coastguard Worker                        node {
533*89c4ff92SAndroid Build Coastguard Worker                            input: "AddInput"
534*89c4ff92SAndroid Build Coastguard Worker                            input: "AddOutput"
535*89c4ff92SAndroid Build Coastguard Worker                            output: "Output"
536*89c4ff92SAndroid Build Coastguard Worker                            name: "FinalAdd"
537*89c4ff92SAndroid Build Coastguard Worker                            op_type: "Add"
538*89c4ff92SAndroid Build Coastguard Worker                        }
539*89c4ff92SAndroid Build Coastguard Worker                        value_info {
540*89c4ff92SAndroid Build Coastguard Worker                             name: "AddInput"
541*89c4ff92SAndroid Build Coastguard Worker                             type {
542*89c4ff92SAndroid Build Coastguard Worker                               tensor_type {
543*89c4ff92SAndroid Build Coastguard Worker                                 elem_type: 1
544*89c4ff92SAndroid Build Coastguard Worker                                 shape {
545*89c4ff92SAndroid Build Coastguard Worker                                   dim {
546*89c4ff92SAndroid Build Coastguard Worker                                     dim_value: 1
547*89c4ff92SAndroid Build Coastguard Worker                                   }
548*89c4ff92SAndroid Build Coastguard Worker                                   dim {
549*89c4ff92SAndroid Build Coastguard Worker                                     dim_value: 1
550*89c4ff92SAndroid Build Coastguard Worker                                   }
551*89c4ff92SAndroid Build Coastguard Worker                                 }
552*89c4ff92SAndroid Build Coastguard Worker                               }
553*89c4ff92SAndroid Build Coastguard Worker                             }
554*89c4ff92SAndroid Build Coastguard Worker                           }
555*89c4ff92SAndroid Build Coastguard Worker                       value_info {
556*89c4ff92SAndroid Build Coastguard Worker                            name:  "AddOutput"
557*89c4ff92SAndroid Build Coastguard Worker                            type {
558*89c4ff92SAndroid Build Coastguard Worker                               tensor_type {
559*89c4ff92SAndroid Build Coastguard Worker                                 elem_type: 1
560*89c4ff92SAndroid Build Coastguard Worker                                 shape {
561*89c4ff92SAndroid Build Coastguard Worker                                   dim {
562*89c4ff92SAndroid Build Coastguard Worker                                      dim_value: 1
563*89c4ff92SAndroid Build Coastguard Worker                                   }
564*89c4ff92SAndroid Build Coastguard Worker                                   dim {
565*89c4ff92SAndroid Build Coastguard Worker                                      dim_value: 1
566*89c4ff92SAndroid Build Coastguard Worker                                   }
567*89c4ff92SAndroid Build Coastguard Worker                                 }
568*89c4ff92SAndroid Build Coastguard Worker                               }
569*89c4ff92SAndroid Build Coastguard Worker                            }
570*89c4ff92SAndroid Build Coastguard Worker                        }
571*89c4ff92SAndroid Build Coastguard Worker                        output {
572*89c4ff92SAndroid Build Coastguard Worker                              name:  "Output"
573*89c4ff92SAndroid Build Coastguard Worker                              type {
574*89c4ff92SAndroid Build Coastguard Worker                                 tensor_type {
575*89c4ff92SAndroid Build Coastguard Worker                                   elem_type: 1
576*89c4ff92SAndroid Build Coastguard Worker                                   shape {
577*89c4ff92SAndroid Build Coastguard Worker                                     dim {
578*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
579*89c4ff92SAndroid Build Coastguard Worker                                     }
580*89c4ff92SAndroid Build Coastguard Worker                                     dim {
581*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
582*89c4ff92SAndroid Build Coastguard Worker                                     }
583*89c4ff92SAndroid Build Coastguard Worker                                   }
584*89c4ff92SAndroid Build Coastguard Worker                                 }
585*89c4ff92SAndroid Build Coastguard Worker                              }
586*89c4ff92SAndroid Build Coastguard Worker                          }
587*89c4ff92SAndroid Build Coastguard Worker                     }
588*89c4ff92SAndroid Build Coastguard Worker                     opset_import {
589*89c4ff92SAndroid Build Coastguard Worker                        version: 7
590*89c4ff92SAndroid Build Coastguard Worker                      })";
591*89c4ff92SAndroid Build Coastguard Worker         Setup();
592*89c4ff92SAndroid Build Coastguard Worker     }
593*89c4ff92SAndroid Build Coastguard Worker };
594*89c4ff92SAndroid Build Coastguard Worker 
595*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MatMulUsedInTwoFcStaggeredFixture, "MatMulUsedInTwoFcStaggered")
596*89c4ff92SAndroid Build Coastguard Worker {
597*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input", { 3 }}}, {{"Output", { 13 }}});
598*89c4ff92SAndroid Build Coastguard Worker }
599*89c4ff92SAndroid Build Coastguard Worker 
600*89c4ff92SAndroid Build Coastguard Worker }
601