xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Gemm.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Gemm")
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker struct GemmFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14*89c4ff92SAndroid Build Coastguard Worker {
GemmFixtureGemmFixture15*89c4ff92SAndroid Build Coastguard Worker     GemmFixture(const std::string& alpha,
16*89c4ff92SAndroid Build Coastguard Worker                 const std::string& beta,
17*89c4ff92SAndroid Build Coastguard Worker                 const std::string& transA,
18*89c4ff92SAndroid Build Coastguard Worker                 const std::string& transB,
19*89c4ff92SAndroid Build Coastguard Worker                 const std::vector<int>& inputAShape,
20*89c4ff92SAndroid Build Coastguard Worker                 const std::vector<int>& inputBShape,
21*89c4ff92SAndroid Build Coastguard Worker                 const std::vector<int>& inputCShape,
22*89c4ff92SAndroid Build Coastguard Worker                 const std::vector<int>& outputShape)
23*89c4ff92SAndroid Build Coastguard Worker     {
24*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
25*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 8
26*89c4ff92SAndroid Build Coastguard Worker                     producer_name: "onnx-example"
27*89c4ff92SAndroid Build Coastguard Worker                     graph {
28*89c4ff92SAndroid Build Coastguard Worker                       node {
29*89c4ff92SAndroid Build Coastguard Worker                         input: "A"
30*89c4ff92SAndroid Build Coastguard Worker                         input: "B"
31*89c4ff92SAndroid Build Coastguard Worker                         input: "C"
32*89c4ff92SAndroid Build Coastguard Worker                         output: "Output"
33*89c4ff92SAndroid Build Coastguard Worker                         op_type: "Gemm"
34*89c4ff92SAndroid Build Coastguard Worker                         attribute {
35*89c4ff92SAndroid Build Coastguard Worker                           name: "alpha"
36*89c4ff92SAndroid Build Coastguard Worker                           f: )" + alpha + R"(
37*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
38*89c4ff92SAndroid Build Coastguard Worker                         }
39*89c4ff92SAndroid Build Coastguard Worker                         attribute {
40*89c4ff92SAndroid Build Coastguard Worker                           name: "beta"
41*89c4ff92SAndroid Build Coastguard Worker                           f: )" + beta + R"(
42*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
43*89c4ff92SAndroid Build Coastguard Worker                         }
44*89c4ff92SAndroid Build Coastguard Worker                         attribute {
45*89c4ff92SAndroid Build Coastguard Worker                           name: "transA"
46*89c4ff92SAndroid Build Coastguard Worker                           i: )" + transA + R"(
47*89c4ff92SAndroid Build Coastguard Worker                           type: INT
48*89c4ff92SAndroid Build Coastguard Worker                         }
49*89c4ff92SAndroid Build Coastguard Worker                         attribute {
50*89c4ff92SAndroid Build Coastguard Worker                           name: "transB"
51*89c4ff92SAndroid Build Coastguard Worker                           i: )" + transB + R"(
52*89c4ff92SAndroid Build Coastguard Worker                           type: INT
53*89c4ff92SAndroid Build Coastguard Worker                         }
54*89c4ff92SAndroid Build Coastguard Worker                       }
55*89c4ff92SAndroid Build Coastguard Worker                       name: "gem-model"
56*89c4ff92SAndroid Build Coastguard Worker                       input {
57*89c4ff92SAndroid Build Coastguard Worker                         name: "A"
58*89c4ff92SAndroid Build Coastguard Worker                         type {
59*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
60*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
61*89c4ff92SAndroid Build Coastguard Worker                             shape {
62*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
63*89c4ff92SAndroid Build Coastguard Worker                             }
64*89c4ff92SAndroid Build Coastguard Worker                           }
65*89c4ff92SAndroid Build Coastguard Worker                         }
66*89c4ff92SAndroid Build Coastguard Worker                       }
67*89c4ff92SAndroid Build Coastguard Worker                       input {
68*89c4ff92SAndroid Build Coastguard Worker                         name: "B"
69*89c4ff92SAndroid Build Coastguard Worker                         type {
70*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
71*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
72*89c4ff92SAndroid Build Coastguard Worker                             shape {
73*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
74*89c4ff92SAndroid Build Coastguard Worker                             }
75*89c4ff92SAndroid Build Coastguard Worker                           }
76*89c4ff92SAndroid Build Coastguard Worker                         }
77*89c4ff92SAndroid Build Coastguard Worker                       }
78*89c4ff92SAndroid Build Coastguard Worker                       input {
79*89c4ff92SAndroid Build Coastguard Worker                         name: "C"
80*89c4ff92SAndroid Build Coastguard Worker                         type {
81*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
82*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
83*89c4ff92SAndroid Build Coastguard Worker                             shape {
84*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"(
85*89c4ff92SAndroid Build Coastguard Worker                             }
86*89c4ff92SAndroid Build Coastguard Worker                           }
87*89c4ff92SAndroid Build Coastguard Worker                         }
88*89c4ff92SAndroid Build Coastguard Worker                       }
89*89c4ff92SAndroid Build Coastguard Worker                       output {
90*89c4ff92SAndroid Build Coastguard Worker                         name: "Output"
91*89c4ff92SAndroid Build Coastguard Worker                         type {
92*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
93*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
94*89c4ff92SAndroid Build Coastguard Worker                             shape {
95*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
96*89c4ff92SAndroid Build Coastguard Worker                             }
97*89c4ff92SAndroid Build Coastguard Worker                           }
98*89c4ff92SAndroid Build Coastguard Worker                         }
99*89c4ff92SAndroid Build Coastguard Worker                       }
100*89c4ff92SAndroid Build Coastguard Worker                     })";
101*89c4ff92SAndroid Build Coastguard Worker     }
102*89c4ff92SAndroid Build Coastguard Worker };
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker struct GemmAllAttributesFixture : GemmFixture
105*89c4ff92SAndroid Build Coastguard Worker {
GemmAllAttributesFixtureGemmAllAttributesFixture106*89c4ff92SAndroid Build Coastguard Worker     GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 })
107*89c4ff92SAndroid Build Coastguard Worker     {
108*89c4ff92SAndroid Build Coastguard Worker         Setup();
109*89c4ff92SAndroid Build Coastguard Worker     }
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker struct GemmSimpleFixture : GemmFixture
113*89c4ff92SAndroid Build Coastguard Worker {
GemmSimpleFixtureGemmSimpleFixture114*89c4ff92SAndroid Build Coastguard Worker     GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 })
115*89c4ff92SAndroid Build Coastguard Worker     {
116*89c4ff92SAndroid Build Coastguard Worker         Setup();
117*89c4ff92SAndroid Build Coastguard Worker     }
118*89c4ff92SAndroid Build Coastguard Worker };
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker struct GemmTransAFixture : GemmFixture
121*89c4ff92SAndroid Build Coastguard Worker {
GemmTransAFixtureGemmTransAFixture122*89c4ff92SAndroid Build Coastguard Worker     GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 })
123*89c4ff92SAndroid Build Coastguard Worker     {
124*89c4ff92SAndroid Build Coastguard Worker         Setup();
125*89c4ff92SAndroid Build Coastguard Worker     }
126*89c4ff92SAndroid Build Coastguard Worker };
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker struct GemmTransBFixture : GemmFixture
129*89c4ff92SAndroid Build Coastguard Worker {
GemmTransBFixtureGemmTransBFixture130*89c4ff92SAndroid Build Coastguard Worker     GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 })
131*89c4ff92SAndroid Build Coastguard Worker     {
132*89c4ff92SAndroid Build Coastguard Worker         Setup();
133*89c4ff92SAndroid Build Coastguard Worker     }
134*89c4ff92SAndroid Build Coastguard Worker };
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker struct GemmParseExceptionFixture : GemmFixture
137*89c4ff92SAndroid Build Coastguard Worker {
GemmParseExceptionFixtureGemmParseExceptionFixture138*89c4ff92SAndroid Build Coastguard Worker     GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {}
139*89c4ff92SAndroid Build Coastguard Worker };
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest")
142*89c4ff92SAndroid Build Coastguard Worker {
143*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
144*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
145*89c4ff92SAndroid Build Coastguard Worker                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
146*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
147*89c4ff92SAndroid Build Coastguard Worker                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
148*89c4ff92SAndroid Build Coastguard Worker                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
149*89c4ff92SAndroid Build Coastguard Worker                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
150*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
151*89c4ff92SAndroid Build Coastguard Worker                                     12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
152*89c4ff92SAndroid Build Coastguard Worker                                     10.035f, 32.07f,  54.105f, 76.14f, 98.175f }}});
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest")
156*89c4ff92SAndroid Build Coastguard Worker {
157*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
158*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
159*89c4ff92SAndroid Build Coastguard Worker                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
160*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
161*89c4ff92SAndroid Build Coastguard Worker                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
162*89c4ff92SAndroid Build Coastguard Worker                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
163*89c4ff92SAndroid Build Coastguard Worker                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
164*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
165*89c4ff92SAndroid Build Coastguard Worker                                     196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
166*89c4ff92SAndroid Build Coastguard Worker                                     60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest")
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
172*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
173*89c4ff92SAndroid Build Coastguard Worker                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
174*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
175*89c4ff92SAndroid Build Coastguard Worker                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
176*89c4ff92SAndroid Build Coastguard Worker                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
177*89c4ff92SAndroid Build Coastguard Worker                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
178*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f,
179*89c4ff92SAndroid Build Coastguard Worker                                     146.1f, 172.2f, 198.3f, 224.4f, 250.5f,
180*89c4ff92SAndroid Build Coastguard Worker                                     112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}});
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest")
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
186*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
187*89c4ff92SAndroid Build Coastguard Worker                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
188*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
189*89c4ff92SAndroid Build Coastguard Worker                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
190*89c4ff92SAndroid Build Coastguard Worker                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
191*89c4ff92SAndroid Build Coastguard Worker                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
192*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f,
193*89c4ff92SAndroid Build Coastguard Worker                                     60.1f, 164.2f, 268.3f, 372.4f, 476.5f,
194*89c4ff92SAndroid Build Coastguard Worker                                     20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}});
195*89c4ff92SAndroid Build Coastguard Worker }
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest")
198*89c4ff92SAndroid Build Coastguard Worker {
199*89c4ff92SAndroid Build Coastguard Worker     // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension)
200*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::ParseException);
201*89c4ff92SAndroid Build Coastguard Worker }
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
204*89c4ff92SAndroid Build Coastguard Worker {
GemmConstantFixtureGemmConstantFixture205*89c4ff92SAndroid Build Coastguard Worker     GemmConstantFixture()
206*89c4ff92SAndroid Build Coastguard Worker     {
207*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
208*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 8
209*89c4ff92SAndroid Build Coastguard Worker                     producer_name: "onnx-example"
210*89c4ff92SAndroid Build Coastguard Worker                     graph {
211*89c4ff92SAndroid Build Coastguard Worker                       node {
212*89c4ff92SAndroid Build Coastguard Worker                         input: "A"
213*89c4ff92SAndroid Build Coastguard Worker                         input: "B"
214*89c4ff92SAndroid Build Coastguard Worker                         input: "C"
215*89c4ff92SAndroid Build Coastguard Worker                         output: "Output"
216*89c4ff92SAndroid Build Coastguard Worker                         op_type: "Gemm"
217*89c4ff92SAndroid Build Coastguard Worker                         attribute {
218*89c4ff92SAndroid Build Coastguard Worker                           name: "alpha"
219*89c4ff92SAndroid Build Coastguard Worker                           f: 0.25
220*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
221*89c4ff92SAndroid Build Coastguard Worker                         }
222*89c4ff92SAndroid Build Coastguard Worker                         attribute {
223*89c4ff92SAndroid Build Coastguard Worker                           name: "beta"
224*89c4ff92SAndroid Build Coastguard Worker                           f: 0.35
225*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
226*89c4ff92SAndroid Build Coastguard Worker                         }
227*89c4ff92SAndroid Build Coastguard Worker                         attribute {
228*89c4ff92SAndroid Build Coastguard Worker                           name: "transA"
229*89c4ff92SAndroid Build Coastguard Worker                           i: 1
230*89c4ff92SAndroid Build Coastguard Worker                           type: INT
231*89c4ff92SAndroid Build Coastguard Worker                         }
232*89c4ff92SAndroid Build Coastguard Worker                         attribute {
233*89c4ff92SAndroid Build Coastguard Worker                           name: "transB"
234*89c4ff92SAndroid Build Coastguard Worker                           i: 1
235*89c4ff92SAndroid Build Coastguard Worker                           type: INT
236*89c4ff92SAndroid Build Coastguard Worker                         }
237*89c4ff92SAndroid Build Coastguard Worker                       }
238*89c4ff92SAndroid Build Coastguard Worker                       name: "gem-model"
239*89c4ff92SAndroid Build Coastguard Worker                       initializer {
240*89c4ff92SAndroid Build Coastguard Worker                         dims: 5
241*89c4ff92SAndroid Build Coastguard Worker                         dims: 4
242*89c4ff92SAndroid Build Coastguard Worker                         data_type: 1
243*89c4ff92SAndroid Build Coastguard Worker                         float_data: 1.0
244*89c4ff92SAndroid Build Coastguard Worker                         float_data: 2.0
245*89c4ff92SAndroid Build Coastguard Worker                         float_data: 3.0
246*89c4ff92SAndroid Build Coastguard Worker                         float_data: 4.0
247*89c4ff92SAndroid Build Coastguard Worker                         float_data: 5.0
248*89c4ff92SAndroid Build Coastguard Worker                         float_data: 6.0
249*89c4ff92SAndroid Build Coastguard Worker                         float_data: 7.0
250*89c4ff92SAndroid Build Coastguard Worker                         float_data: 8.0
251*89c4ff92SAndroid Build Coastguard Worker                         float_data: 9.0
252*89c4ff92SAndroid Build Coastguard Worker                         float_data: 10.0
253*89c4ff92SAndroid Build Coastguard Worker                         float_data: 11.0
254*89c4ff92SAndroid Build Coastguard Worker                         float_data: 12.0
255*89c4ff92SAndroid Build Coastguard Worker                         float_data: 13.0
256*89c4ff92SAndroid Build Coastguard Worker                         float_data: 14.0
257*89c4ff92SAndroid Build Coastguard Worker                         float_data: 15.0
258*89c4ff92SAndroid Build Coastguard Worker                         float_data: 16.0
259*89c4ff92SAndroid Build Coastguard Worker                         float_data: 17.0
260*89c4ff92SAndroid Build Coastguard Worker                         float_data: 18.0
261*89c4ff92SAndroid Build Coastguard Worker                         float_data: 19.0
262*89c4ff92SAndroid Build Coastguard Worker                         float_data: 20.0
263*89c4ff92SAndroid Build Coastguard Worker                         name: "B"
264*89c4ff92SAndroid Build Coastguard Worker                       }
265*89c4ff92SAndroid Build Coastguard Worker                       initializer {
266*89c4ff92SAndroid Build Coastguard Worker                         dims: 1
267*89c4ff92SAndroid Build Coastguard Worker                         dims: 5
268*89c4ff92SAndroid Build Coastguard Worker                         data_type: 1
269*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.1
270*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.2
271*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.3
272*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.4
273*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.5
274*89c4ff92SAndroid Build Coastguard Worker                         name: "C"
275*89c4ff92SAndroid Build Coastguard Worker                       }
276*89c4ff92SAndroid Build Coastguard Worker                       input {
277*89c4ff92SAndroid Build Coastguard Worker                         name: "A"
278*89c4ff92SAndroid Build Coastguard Worker                         type {
279*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
280*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
281*89c4ff92SAndroid Build Coastguard Worker                             shape {
282*89c4ff92SAndroid Build Coastguard Worker                               dim {
283*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 4
284*89c4ff92SAndroid Build Coastguard Worker                               }
285*89c4ff92SAndroid Build Coastguard Worker                               dim {
286*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
287*89c4ff92SAndroid Build Coastguard Worker                               }
288*89c4ff92SAndroid Build Coastguard Worker                             }
289*89c4ff92SAndroid Build Coastguard Worker                           }
290*89c4ff92SAndroid Build Coastguard Worker                         }
291*89c4ff92SAndroid Build Coastguard Worker                       }
292*89c4ff92SAndroid Build Coastguard Worker                       output {
293*89c4ff92SAndroid Build Coastguard Worker                         name: "Output"
294*89c4ff92SAndroid Build Coastguard Worker                         type {
295*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
296*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
297*89c4ff92SAndroid Build Coastguard Worker                             shape {
298*89c4ff92SAndroid Build Coastguard Worker                               dim {
299*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
300*89c4ff92SAndroid Build Coastguard Worker                               }
301*89c4ff92SAndroid Build Coastguard Worker                               dim {
302*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 5
303*89c4ff92SAndroid Build Coastguard Worker                               }
304*89c4ff92SAndroid Build Coastguard Worker                             }
305*89c4ff92SAndroid Build Coastguard Worker                           }
306*89c4ff92SAndroid Build Coastguard Worker                         }
307*89c4ff92SAndroid Build Coastguard Worker                       }
308*89c4ff92SAndroid Build Coastguard Worker                     })";
309*89c4ff92SAndroid Build Coastguard Worker         Setup();
310*89c4ff92SAndroid Build Coastguard Worker     }
311*89c4ff92SAndroid Build Coastguard Worker };
312*89c4ff92SAndroid Build Coastguard Worker 
313*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest")
314*89c4ff92SAndroid Build Coastguard Worker {
315*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
316*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
317*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
318*89c4ff92SAndroid Build Coastguard Worker                                     12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
319*89c4ff92SAndroid Build Coastguard Worker                                     10.035f, 32.07f,  54.105f, 76.14f, 98.175f }}});
320*89c4ff92SAndroid Build Coastguard Worker }
321*89c4ff92SAndroid Build Coastguard Worker 
322*89c4ff92SAndroid Build Coastguard Worker struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
323*89c4ff92SAndroid Build Coastguard Worker {
GemmConstantSimpleFixtureGemmConstantSimpleFixture324*89c4ff92SAndroid Build Coastguard Worker     GemmConstantSimpleFixture()
325*89c4ff92SAndroid Build Coastguard Worker     {
326*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
327*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 8
328*89c4ff92SAndroid Build Coastguard Worker                     producer_name: "onnx-example"
329*89c4ff92SAndroid Build Coastguard Worker                     graph {
330*89c4ff92SAndroid Build Coastguard Worker                       node {
331*89c4ff92SAndroid Build Coastguard Worker                         input: "A"
332*89c4ff92SAndroid Build Coastguard Worker                         input: "B"
333*89c4ff92SAndroid Build Coastguard Worker                         input: "C"
334*89c4ff92SAndroid Build Coastguard Worker                         output: "Output"
335*89c4ff92SAndroid Build Coastguard Worker                         op_type: "Gemm"
336*89c4ff92SAndroid Build Coastguard Worker                         attribute {
337*89c4ff92SAndroid Build Coastguard Worker                           name: "alpha"
338*89c4ff92SAndroid Build Coastguard Worker                           f: 1
339*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
340*89c4ff92SAndroid Build Coastguard Worker                         }
341*89c4ff92SAndroid Build Coastguard Worker                         attribute {
342*89c4ff92SAndroid Build Coastguard Worker                           name: "beta"
343*89c4ff92SAndroid Build Coastguard Worker                           f: 1
344*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
345*89c4ff92SAndroid Build Coastguard Worker                         }
346*89c4ff92SAndroid Build Coastguard Worker                         attribute {
347*89c4ff92SAndroid Build Coastguard Worker                           name: "transA"
348*89c4ff92SAndroid Build Coastguard Worker                           i: 0
349*89c4ff92SAndroid Build Coastguard Worker                           type: INT
350*89c4ff92SAndroid Build Coastguard Worker                         }
351*89c4ff92SAndroid Build Coastguard Worker                         attribute {
352*89c4ff92SAndroid Build Coastguard Worker                           name: "transB"
353*89c4ff92SAndroid Build Coastguard Worker                           i: 0
354*89c4ff92SAndroid Build Coastguard Worker                           type: INT
355*89c4ff92SAndroid Build Coastguard Worker                         }
356*89c4ff92SAndroid Build Coastguard Worker                       }
357*89c4ff92SAndroid Build Coastguard Worker                       name: "gem-model"
358*89c4ff92SAndroid Build Coastguard Worker                       initializer {
359*89c4ff92SAndroid Build Coastguard Worker                         dims: 4
360*89c4ff92SAndroid Build Coastguard Worker                         dims: 5
361*89c4ff92SAndroid Build Coastguard Worker                         data_type: 1
362*89c4ff92SAndroid Build Coastguard Worker                         float_data: 1.0
363*89c4ff92SAndroid Build Coastguard Worker                         float_data: 2.0
364*89c4ff92SAndroid Build Coastguard Worker                         float_data: 3.0
365*89c4ff92SAndroid Build Coastguard Worker                         float_data: 4.0
366*89c4ff92SAndroid Build Coastguard Worker                         float_data: 5.0
367*89c4ff92SAndroid Build Coastguard Worker                         float_data: 6.0
368*89c4ff92SAndroid Build Coastguard Worker                         float_data: 7.0
369*89c4ff92SAndroid Build Coastguard Worker                         float_data: 8.0
370*89c4ff92SAndroid Build Coastguard Worker                         float_data: 9.0
371*89c4ff92SAndroid Build Coastguard Worker                         float_data: 10.0
372*89c4ff92SAndroid Build Coastguard Worker                         float_data: 11.0
373*89c4ff92SAndroid Build Coastguard Worker                         float_data: 12.0
374*89c4ff92SAndroid Build Coastguard Worker                         float_data: 13.0
375*89c4ff92SAndroid Build Coastguard Worker                         float_data: 14.0
376*89c4ff92SAndroid Build Coastguard Worker                         float_data: 15.0
377*89c4ff92SAndroid Build Coastguard Worker                         float_data: 16.0
378*89c4ff92SAndroid Build Coastguard Worker                         float_data: 17.0
379*89c4ff92SAndroid Build Coastguard Worker                         float_data: 18.0
380*89c4ff92SAndroid Build Coastguard Worker                         float_data: 19.0
381*89c4ff92SAndroid Build Coastguard Worker                         float_data: 20.0
382*89c4ff92SAndroid Build Coastguard Worker                         name: "B"
383*89c4ff92SAndroid Build Coastguard Worker                       }
384*89c4ff92SAndroid Build Coastguard Worker                       initializer {
385*89c4ff92SAndroid Build Coastguard Worker                         dims: 1
386*89c4ff92SAndroid Build Coastguard Worker                         dims: 5
387*89c4ff92SAndroid Build Coastguard Worker                         data_type: 1
388*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.1
389*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.2
390*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.3
391*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.4
392*89c4ff92SAndroid Build Coastguard Worker                         float_data: 0.5
393*89c4ff92SAndroid Build Coastguard Worker                         name: "C"
394*89c4ff92SAndroid Build Coastguard Worker                       }
395*89c4ff92SAndroid Build Coastguard Worker                       input {
396*89c4ff92SAndroid Build Coastguard Worker                         name: "A"
397*89c4ff92SAndroid Build Coastguard Worker                         type {
398*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
399*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
400*89c4ff92SAndroid Build Coastguard Worker                             shape {
401*89c4ff92SAndroid Build Coastguard Worker                               dim {
402*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
403*89c4ff92SAndroid Build Coastguard Worker                               }
404*89c4ff92SAndroid Build Coastguard Worker                               dim {
405*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 4
406*89c4ff92SAndroid Build Coastguard Worker                               }
407*89c4ff92SAndroid Build Coastguard Worker                             }
408*89c4ff92SAndroid Build Coastguard Worker                           }
409*89c4ff92SAndroid Build Coastguard Worker                         }
410*89c4ff92SAndroid Build Coastguard Worker                       }
411*89c4ff92SAndroid Build Coastguard Worker                       output {
412*89c4ff92SAndroid Build Coastguard Worker                         name: "Output"
413*89c4ff92SAndroid Build Coastguard Worker                         type {
414*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
415*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
416*89c4ff92SAndroid Build Coastguard Worker                             shape {
417*89c4ff92SAndroid Build Coastguard Worker                               dim {
418*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
419*89c4ff92SAndroid Build Coastguard Worker                               }
420*89c4ff92SAndroid Build Coastguard Worker                               dim {
421*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 5
422*89c4ff92SAndroid Build Coastguard Worker                               }
423*89c4ff92SAndroid Build Coastguard Worker                             }
424*89c4ff92SAndroid Build Coastguard Worker                           }
425*89c4ff92SAndroid Build Coastguard Worker                         }
426*89c4ff92SAndroid Build Coastguard Worker                       }
427*89c4ff92SAndroid Build Coastguard Worker                     })";
428*89c4ff92SAndroid Build Coastguard Worker         Setup();
429*89c4ff92SAndroid Build Coastguard Worker     }
430*89c4ff92SAndroid Build Coastguard Worker };
431*89c4ff92SAndroid Build Coastguard Worker 
432*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest")
433*89c4ff92SAndroid Build Coastguard Worker {
434*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
435*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
436*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
437*89c4ff92SAndroid Build Coastguard Worker                                     196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
438*89c4ff92SAndroid Build Coastguard Worker                                     60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
439*89c4ff92SAndroid Build Coastguard Worker }
440*89c4ff92SAndroid Build Coastguard Worker 
441*89c4ff92SAndroid Build Coastguard Worker struct GemmABFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
442*89c4ff92SAndroid Build Coastguard Worker {
GemmABFixtureGemmABFixture443*89c4ff92SAndroid Build Coastguard Worker     GemmABFixture(const std::string& alpha,
444*89c4ff92SAndroid Build Coastguard Worker                   const std::string& beta,
445*89c4ff92SAndroid Build Coastguard Worker                   const std::string& transA,
446*89c4ff92SAndroid Build Coastguard Worker                   const std::string& transB,
447*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<int>& inputAShape,
448*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<int>& inputBShape,
449*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<int>& outputShape)
450*89c4ff92SAndroid Build Coastguard Worker     {
451*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
452*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 8
453*89c4ff92SAndroid Build Coastguard Worker                     producer_name: "onnx-example"
454*89c4ff92SAndroid Build Coastguard Worker                     graph {
455*89c4ff92SAndroid Build Coastguard Worker                       node {
456*89c4ff92SAndroid Build Coastguard Worker                         input: "A"
457*89c4ff92SAndroid Build Coastguard Worker                         input: "B"
458*89c4ff92SAndroid Build Coastguard Worker                         output: "Output"
459*89c4ff92SAndroid Build Coastguard Worker                         op_type: "Gemm"
460*89c4ff92SAndroid Build Coastguard Worker                         attribute {
461*89c4ff92SAndroid Build Coastguard Worker                           name: "alpha"
462*89c4ff92SAndroid Build Coastguard Worker                           f: )" + alpha + R"(
463*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
464*89c4ff92SAndroid Build Coastguard Worker                         }
465*89c4ff92SAndroid Build Coastguard Worker                         attribute {
466*89c4ff92SAndroid Build Coastguard Worker                           name: "beta"
467*89c4ff92SAndroid Build Coastguard Worker                           f: )" + beta + R"(
468*89c4ff92SAndroid Build Coastguard Worker                           type: FLOAT
469*89c4ff92SAndroid Build Coastguard Worker                         }
470*89c4ff92SAndroid Build Coastguard Worker                         attribute {
471*89c4ff92SAndroid Build Coastguard Worker                           name: "transA"
472*89c4ff92SAndroid Build Coastguard Worker                           i: )" + transA + R"(
473*89c4ff92SAndroid Build Coastguard Worker                           type: INT
474*89c4ff92SAndroid Build Coastguard Worker                         }
475*89c4ff92SAndroid Build Coastguard Worker                         attribute {
476*89c4ff92SAndroid Build Coastguard Worker                           name: "transB"
477*89c4ff92SAndroid Build Coastguard Worker                           i: )" + transB + R"(
478*89c4ff92SAndroid Build Coastguard Worker                           type: INT
479*89c4ff92SAndroid Build Coastguard Worker                         }
480*89c4ff92SAndroid Build Coastguard Worker                       }
481*89c4ff92SAndroid Build Coastguard Worker                       name: "gem-model"
482*89c4ff92SAndroid Build Coastguard Worker                       input {
483*89c4ff92SAndroid Build Coastguard Worker                         name: "A"
484*89c4ff92SAndroid Build Coastguard Worker                         type {
485*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
486*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
487*89c4ff92SAndroid Build Coastguard Worker                             shape {
488*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
489*89c4ff92SAndroid Build Coastguard Worker                             }
490*89c4ff92SAndroid Build Coastguard Worker                           }
491*89c4ff92SAndroid Build Coastguard Worker                         }
492*89c4ff92SAndroid Build Coastguard Worker                       }
493*89c4ff92SAndroid Build Coastguard Worker                       input {
494*89c4ff92SAndroid Build Coastguard Worker                         name: "B"
495*89c4ff92SAndroid Build Coastguard Worker                         type {
496*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
497*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
498*89c4ff92SAndroid Build Coastguard Worker                             shape {
499*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
500*89c4ff92SAndroid Build Coastguard Worker                             }
501*89c4ff92SAndroid Build Coastguard Worker                           }
502*89c4ff92SAndroid Build Coastguard Worker                         }
503*89c4ff92SAndroid Build Coastguard Worker                       }
504*89c4ff92SAndroid Build Coastguard Worker                       output {
505*89c4ff92SAndroid Build Coastguard Worker                         name: "Output"
506*89c4ff92SAndroid Build Coastguard Worker                         type {
507*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
508*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
509*89c4ff92SAndroid Build Coastguard Worker                             shape {
510*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
511*89c4ff92SAndroid Build Coastguard Worker                             }
512*89c4ff92SAndroid Build Coastguard Worker                           }
513*89c4ff92SAndroid Build Coastguard Worker                         }
514*89c4ff92SAndroid Build Coastguard Worker                       }
515*89c4ff92SAndroid Build Coastguard Worker                     })";
516*89c4ff92SAndroid Build Coastguard Worker         Setup();
517*89c4ff92SAndroid Build Coastguard Worker     }
518*89c4ff92SAndroid Build Coastguard Worker };
519*89c4ff92SAndroid Build Coastguard Worker 
520*89c4ff92SAndroid Build Coastguard Worker struct GemmAlphaTransAFixture : GemmABFixture
521*89c4ff92SAndroid Build Coastguard Worker {
GemmAlphaTransAFixtureGemmAlphaTransAFixture522*89c4ff92SAndroid Build Coastguard Worker     GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {}
523*89c4ff92SAndroid Build Coastguard Worker };
524*89c4ff92SAndroid Build Coastguard Worker 
525*89c4ff92SAndroid Build Coastguard Worker struct GemmAlphaTransBFixture : GemmABFixture
526*89c4ff92SAndroid Build Coastguard Worker {
GemmAlphaTransBFixtureGemmAlphaTransBFixture527*89c4ff92SAndroid Build Coastguard Worker     GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {}
528*89c4ff92SAndroid Build Coastguard Worker };
529*89c4ff92SAndroid Build Coastguard Worker 
530*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest")
531*89c4ff92SAndroid Build Coastguard Worker {
532*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
533*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
534*89c4ff92SAndroid Build Coastguard Worker                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
535*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
536*89c4ff92SAndroid Build Coastguard Worker                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
537*89c4ff92SAndroid Build Coastguard Worker                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
538*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f,
539*89c4ff92SAndroid Build Coastguard Worker                                     36.5f, 43.0f, 49.5f, 56.0f, 62.5f,
540*89c4ff92SAndroid Build Coastguard Worker                                     28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}});
541*89c4ff92SAndroid Build Coastguard Worker }
542*89c4ff92SAndroid Build Coastguard Worker 
543*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest")
544*89c4ff92SAndroid Build Coastguard Worker {
545*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
546*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
547*89c4ff92SAndroid Build Coastguard Worker                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
548*89c4ff92SAndroid Build Coastguard Worker                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
549*89c4ff92SAndroid Build Coastguard Worker                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
550*89c4ff92SAndroid Build Coastguard Worker                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
551*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f,
552*89c4ff92SAndroid Build Coastguard Worker                                     15.0f, 41.0f, 67.0f, 93.0f, 119.0f,
553*89c4ff92SAndroid Build Coastguard Worker                                     5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}});
554*89c4ff92SAndroid Build Coastguard Worker }
555*89c4ff92SAndroid Build Coastguard Worker 
556*89c4ff92SAndroid Build Coastguard Worker }
557