xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeBatchMatMul.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <doctest/doctest.h>
10 
11 #include <string>
12 
13 TEST_SUITE("Deserializer_BatchMatMul")
14 {
15 struct BatchMatMulFixture : public ParserFlatbuffersSerializeFixture
16 {
BatchMatMulFixtureBatchMatMulFixture17     explicit BatchMatMulFixture(const std::string& inputXShape,
18                                 const std::string& inputYShape,
19                                 const std::string& outputShape,
20                                 const std::string& dataType)
21     {
22         m_JsonString = R"(
23             {
24                 inputIds:[
25                     0,
26                     1
27                 ],
28                 outputIds:[
29                     3
30                 ],
31                 layers:[
32                     {
33                         layer_type:"InputLayer",
34                         layer:{
35                             base:{
36                                 layerBindingId:0,
37                                 base:{
38                                     index:0,
39                                     layerName:"InputXLayer",
40                                     layerType:"Input",
41                                     inputSlots:[
42                                         {
43                                             index:0,
44                                             connection:{
45                                                 sourceLayerIndex:0,
46                                                 outputSlotIndex:0
47                                             },
48 
49                                         }
50                                     ],
51                                     outputSlots:[
52                                         {
53                                             index:0,
54                                             tensorInfo:{
55                                                 dimensions:)" + inputXShape + R"(,
56                                                 dataType:)" + dataType + R"(
57                                             },
58 
59                                         }
60                                     ],
61 
62                                 },
63 
64                             }
65                         },
66 
67                     },
68                     {
69                         layer_type:"InputLayer",
70                         layer:{
71                             base:{
72                                 layerBindingId:1,
73                                 base:{
74                                     index:1,
75                                     layerName:"InputYLayer",
76                                     layerType:"Input",
77                                     inputSlots:[
78                                         {
79                                             index:0,
80                                             connection:{
81                                                 sourceLayerIndex:0,
82                                                 outputSlotIndex:0
83                                             },
84 
85                                         }
86                                     ],
87                                     outputSlots:[
88                                         {
89                                             index:0,
90                                             tensorInfo:{
91                                                 dimensions:)" + inputYShape + R"(,
92                                                 dataType:)" + dataType + R"(
93                                             },
94 
95                                         }
96                                     ],
97 
98                                 },
99 
100                             }
101                         },
102 
103                     },
104                     {
105                         layer_type:"BatchMatMulLayer",
106                         layer:{
107                             base:{
108                                 index:2,
109                                 layerName:"BatchMatMulLayer",
110                                 layerType:"BatchMatMul",
111                                 inputSlots:[
112                                     {
113                                         index:0,
114                                         connection:{
115                                             sourceLayerIndex:0,
116                                             outputSlotIndex:0
117                                         },
118 
119                                     },
120                                     {
121                                         index:1,
122                                         connection:{
123                                             sourceLayerIndex:1,
124                                             outputSlotIndex:0
125                                         },
126 
127                                     }
128                                 ],
129                                 outputSlots:[
130                                     {
131                                         index:0,
132                                         tensorInfo:{
133                                             dimensions:)" + outputShape + R"(,
134                                             dataType:)" + dataType + R"(
135                                         },
136 
137                                     }
138                                 ],
139 
140                             },
141                             descriptor:{
142                                 transposeX:false,
143                                 transposeY:false,
144                                 adjointX:false,
145                                 adjointY:false,
146                                 dataLayoutX:NHWC,
147                                 dataLayoutY:NHWC
148                             }
149                         },
150 
151                     },
152                     {
153                         layer_type:"OutputLayer",
154                         layer:{
155                             base:{
156                                 layerBindingId:0,
157                                 base:{
158                                     index:3,
159                                     layerName:"OutputLayer",
160                                     layerType:"Output",
161                                     inputSlots:[
162                                         {
163                                             index:0,
164                                             connection:{
165                                                 sourceLayerIndex:2,
166                                                 outputSlotIndex:0
167                                             },
168 
169                                         }
170                                     ],
171                                     outputSlots:[
172                                         {
173                                             index:0,
174                                             tensorInfo:{
175                                                 dimensions:)" + outputShape + R"(,
176                                                 dataType:)" + dataType + R"(
177                                             },
178 
179                                         }
180                                     ],
181 
182                                 }
183                             }
184                         },
185 
186                     }
187                 ]
188             }
189         )";
190         Setup();
191     }
192 };
193 
194 struct SimpleBatchMatMulFixture : BatchMatMulFixture
195 {
SimpleBatchMatMulFixtureSimpleBatchMatMulFixture196     SimpleBatchMatMulFixture()
197         : BatchMatMulFixture("[ 1, 2, 2, 1 ]",
198                              "[ 1, 2, 2, 1 ]",
199                              "[ 1, 2, 2, 1 ]",
200                              "Float32")
201     {}
202 };
203 
204 TEST_CASE_FIXTURE(SimpleBatchMatMulFixture, "SimpleBatchMatMulTest")
205 {
206     RunTest<4, armnn::DataType::Float32>(
207         0,
208         {{"InputXLayer", { 1.0f, 2.0f, 3.0f, 4.0f }},
209          {"InputYLayer", { 5.0f, 6.0f, 7.0f, 8.0f }}},
210         {{"OutputLayer", { 19.0f, 22.0f, 43.0f, 50.0f }}});
211 }
212 
213 }