// xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeBatchMatMul.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <string>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_BatchMatMul")
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker struct BatchMatMulFixture : public ParserFlatbuffersSerializeFixture
16*89c4ff92SAndroid Build Coastguard Worker {
BatchMatMulFixtureBatchMatMulFixture17*89c4ff92SAndroid Build Coastguard Worker     explicit BatchMatMulFixture(const std::string& inputXShape,
18*89c4ff92SAndroid Build Coastguard Worker                                 const std::string& inputYShape,
19*89c4ff92SAndroid Build Coastguard Worker                                 const std::string& outputShape,
20*89c4ff92SAndroid Build Coastguard Worker                                 const std::string& dataType)
21*89c4ff92SAndroid Build Coastguard Worker     {
22*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
23*89c4ff92SAndroid Build Coastguard Worker             {
24*89c4ff92SAndroid Build Coastguard Worker                 inputIds:[
25*89c4ff92SAndroid Build Coastguard Worker                     0,
26*89c4ff92SAndroid Build Coastguard Worker                     1
27*89c4ff92SAndroid Build Coastguard Worker                 ],
28*89c4ff92SAndroid Build Coastguard Worker                 outputIds:[
29*89c4ff92SAndroid Build Coastguard Worker                     3
30*89c4ff92SAndroid Build Coastguard Worker                 ],
31*89c4ff92SAndroid Build Coastguard Worker                 layers:[
32*89c4ff92SAndroid Build Coastguard Worker                     {
33*89c4ff92SAndroid Build Coastguard Worker                         layer_type:"InputLayer",
34*89c4ff92SAndroid Build Coastguard Worker                         layer:{
35*89c4ff92SAndroid Build Coastguard Worker                             base:{
36*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId:0,
37*89c4ff92SAndroid Build Coastguard Worker                                 base:{
38*89c4ff92SAndroid Build Coastguard Worker                                     index:0,
39*89c4ff92SAndroid Build Coastguard Worker                                     layerName:"InputXLayer",
40*89c4ff92SAndroid Build Coastguard Worker                                     layerType:"Input",
41*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots:[
42*89c4ff92SAndroid Build Coastguard Worker                                         {
43*89c4ff92SAndroid Build Coastguard Worker                                             index:0,
44*89c4ff92SAndroid Build Coastguard Worker                                             connection:{
45*89c4ff92SAndroid Build Coastguard Worker                                                 sourceLayerIndex:0,
46*89c4ff92SAndroid Build Coastguard Worker                                                 outputSlotIndex:0
47*89c4ff92SAndroid Build Coastguard Worker                                             },
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker                                         }
50*89c4ff92SAndroid Build Coastguard Worker                                     ],
51*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots:[
52*89c4ff92SAndroid Build Coastguard Worker                                         {
53*89c4ff92SAndroid Build Coastguard Worker                                             index:0,
54*89c4ff92SAndroid Build Coastguard Worker                                             tensorInfo:{
55*89c4ff92SAndroid Build Coastguard Worker                                                 dimensions:)" + inputXShape + R"(,
56*89c4ff92SAndroid Build Coastguard Worker                                                 dataType:)" + dataType + R"(
57*89c4ff92SAndroid Build Coastguard Worker                                             },
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker                                         }
60*89c4ff92SAndroid Build Coastguard Worker                                     ],
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker                                 },
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker                             }
65*89c4ff92SAndroid Build Coastguard Worker                         },
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker                     },
68*89c4ff92SAndroid Build Coastguard Worker                     {
69*89c4ff92SAndroid Build Coastguard Worker                         layer_type:"InputLayer",
70*89c4ff92SAndroid Build Coastguard Worker                         layer:{
71*89c4ff92SAndroid Build Coastguard Worker                             base:{
72*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId:1,
73*89c4ff92SAndroid Build Coastguard Worker                                 base:{
74*89c4ff92SAndroid Build Coastguard Worker                                     index:1,
75*89c4ff92SAndroid Build Coastguard Worker                                     layerName:"InputYLayer",
76*89c4ff92SAndroid Build Coastguard Worker                                     layerType:"Input",
77*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots:[
78*89c4ff92SAndroid Build Coastguard Worker                                         {
79*89c4ff92SAndroid Build Coastguard Worker                                             index:0,
80*89c4ff92SAndroid Build Coastguard Worker                                             connection:{
81*89c4ff92SAndroid Build Coastguard Worker                                                 sourceLayerIndex:0,
82*89c4ff92SAndroid Build Coastguard Worker                                                 outputSlotIndex:0
83*89c4ff92SAndroid Build Coastguard Worker                                             },
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker                                         }
86*89c4ff92SAndroid Build Coastguard Worker                                     ],
87*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots:[
88*89c4ff92SAndroid Build Coastguard Worker                                         {
89*89c4ff92SAndroid Build Coastguard Worker                                             index:0,
90*89c4ff92SAndroid Build Coastguard Worker                                             tensorInfo:{
91*89c4ff92SAndroid Build Coastguard Worker                                                 dimensions:)" + inputYShape + R"(,
92*89c4ff92SAndroid Build Coastguard Worker                                                 dataType:)" + dataType + R"(
93*89c4ff92SAndroid Build Coastguard Worker                                             },
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker                                         }
96*89c4ff92SAndroid Build Coastguard Worker                                     ],
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker                                 },
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker                             }
101*89c4ff92SAndroid Build Coastguard Worker                         },
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker                     },
104*89c4ff92SAndroid Build Coastguard Worker                     {
105*89c4ff92SAndroid Build Coastguard Worker                         layer_type:"BatchMatMulLayer",
106*89c4ff92SAndroid Build Coastguard Worker                         layer:{
107*89c4ff92SAndroid Build Coastguard Worker                             base:{
108*89c4ff92SAndroid Build Coastguard Worker                                 index:2,
109*89c4ff92SAndroid Build Coastguard Worker                                 layerName:"BatchMatMulLayer",
110*89c4ff92SAndroid Build Coastguard Worker                                 layerType:"BatchMatMul",
111*89c4ff92SAndroid Build Coastguard Worker                                 inputSlots:[
112*89c4ff92SAndroid Build Coastguard Worker                                     {
113*89c4ff92SAndroid Build Coastguard Worker                                         index:0,
114*89c4ff92SAndroid Build Coastguard Worker                                         connection:{
115*89c4ff92SAndroid Build Coastguard Worker                                             sourceLayerIndex:0,
116*89c4ff92SAndroid Build Coastguard Worker                                             outputSlotIndex:0
117*89c4ff92SAndroid Build Coastguard Worker                                         },
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker                                     },
120*89c4ff92SAndroid Build Coastguard Worker                                     {
121*89c4ff92SAndroid Build Coastguard Worker                                         index:1,
122*89c4ff92SAndroid Build Coastguard Worker                                         connection:{
123*89c4ff92SAndroid Build Coastguard Worker                                             sourceLayerIndex:1,
124*89c4ff92SAndroid Build Coastguard Worker                                             outputSlotIndex:0
125*89c4ff92SAndroid Build Coastguard Worker                                         },
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker                                     }
128*89c4ff92SAndroid Build Coastguard Worker                                 ],
129*89c4ff92SAndroid Build Coastguard Worker                                 outputSlots:[
130*89c4ff92SAndroid Build Coastguard Worker                                     {
131*89c4ff92SAndroid Build Coastguard Worker                                         index:0,
132*89c4ff92SAndroid Build Coastguard Worker                                         tensorInfo:{
133*89c4ff92SAndroid Build Coastguard Worker                                             dimensions:)" + outputShape + R"(,
134*89c4ff92SAndroid Build Coastguard Worker                                             dataType:)" + dataType + R"(
135*89c4ff92SAndroid Build Coastguard Worker                                         },
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker                                     }
138*89c4ff92SAndroid Build Coastguard Worker                                 ],
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker                             },
141*89c4ff92SAndroid Build Coastguard Worker                             descriptor:{
142*89c4ff92SAndroid Build Coastguard Worker                                 transposeX:false,
143*89c4ff92SAndroid Build Coastguard Worker                                 transposeY:false,
144*89c4ff92SAndroid Build Coastguard Worker                                 adjointX:false,
145*89c4ff92SAndroid Build Coastguard Worker                                 adjointY:false,
146*89c4ff92SAndroid Build Coastguard Worker                                 dataLayoutX:NHWC,
147*89c4ff92SAndroid Build Coastguard Worker                                 dataLayoutY:NHWC
148*89c4ff92SAndroid Build Coastguard Worker                             }
149*89c4ff92SAndroid Build Coastguard Worker                         },
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker                     },
152*89c4ff92SAndroid Build Coastguard Worker                     {
153*89c4ff92SAndroid Build Coastguard Worker                         layer_type:"OutputLayer",
154*89c4ff92SAndroid Build Coastguard Worker                         layer:{
155*89c4ff92SAndroid Build Coastguard Worker                             base:{
156*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId:0,
157*89c4ff92SAndroid Build Coastguard Worker                                 base:{
158*89c4ff92SAndroid Build Coastguard Worker                                     index:3,
159*89c4ff92SAndroid Build Coastguard Worker                                     layerName:"OutputLayer",
160*89c4ff92SAndroid Build Coastguard Worker                                     layerType:"Output",
161*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots:[
162*89c4ff92SAndroid Build Coastguard Worker                                         {
163*89c4ff92SAndroid Build Coastguard Worker                                             index:0,
164*89c4ff92SAndroid Build Coastguard Worker                                             connection:{
165*89c4ff92SAndroid Build Coastguard Worker                                                 sourceLayerIndex:2,
166*89c4ff92SAndroid Build Coastguard Worker                                                 outputSlotIndex:0
167*89c4ff92SAndroid Build Coastguard Worker                                             },
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker                                         }
170*89c4ff92SAndroid Build Coastguard Worker                                     ],
171*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots:[
172*89c4ff92SAndroid Build Coastguard Worker                                         {
173*89c4ff92SAndroid Build Coastguard Worker                                             index:0,
174*89c4ff92SAndroid Build Coastguard Worker                                             tensorInfo:{
175*89c4ff92SAndroid Build Coastguard Worker                                                 dimensions:)" + outputShape + R"(,
176*89c4ff92SAndroid Build Coastguard Worker                                                 dataType:)" + dataType + R"(
177*89c4ff92SAndroid Build Coastguard Worker                                             },
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker                                         }
180*89c4ff92SAndroid Build Coastguard Worker                                     ],
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker                                 }
183*89c4ff92SAndroid Build Coastguard Worker                             }
184*89c4ff92SAndroid Build Coastguard Worker                         },
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker                     }
187*89c4ff92SAndroid Build Coastguard Worker                 ]
188*89c4ff92SAndroid Build Coastguard Worker             }
189*89c4ff92SAndroid Build Coastguard Worker         )";
190*89c4ff92SAndroid Build Coastguard Worker         Setup();
191*89c4ff92SAndroid Build Coastguard Worker     }
192*89c4ff92SAndroid Build Coastguard Worker };
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker struct SimpleBatchMatMulFixture : BatchMatMulFixture
195*89c4ff92SAndroid Build Coastguard Worker {
SimpleBatchMatMulFixtureSimpleBatchMatMulFixture196*89c4ff92SAndroid Build Coastguard Worker     SimpleBatchMatMulFixture()
197*89c4ff92SAndroid Build Coastguard Worker         : BatchMatMulFixture("[ 1, 2, 2, 1 ]",
198*89c4ff92SAndroid Build Coastguard Worker                              "[ 1, 2, 2, 1 ]",
199*89c4ff92SAndroid Build Coastguard Worker                              "[ 1, 2, 2, 1 ]",
200*89c4ff92SAndroid Build Coastguard Worker                              "Float32")
201*89c4ff92SAndroid Build Coastguard Worker     {}
202*89c4ff92SAndroid Build Coastguard Worker };
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleBatchMatMulFixture, "SimpleBatchMatMulTest")
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(
207*89c4ff92SAndroid Build Coastguard Worker         0,
208*89c4ff92SAndroid Build Coastguard Worker         {{"InputXLayer", { 1.0f, 2.0f, 3.0f, 4.0f }},
209*89c4ff92SAndroid Build Coastguard Worker          {"InputYLayer", { 5.0f, 6.0f, 7.0f, 8.0f }}},
210*89c4ff92SAndroid Build Coastguard Worker         {{"OutputLayer", { 19.0f, 22.0f, 43.0f, 50.0f }}});
211*89c4ff92SAndroid Build Coastguard Worker }
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker }