xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeMultiplication.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <armnn/utility/IgnoreUnused.hpp>
10 
11 #include <string>
12 
13 TEST_SUITE("Deserializer_Multiplication")
14 {
15 struct MultiplicationFixture : public ParserFlatbuffersSerializeFixture
16 {
MultiplicationFixtureMultiplicationFixture17     explicit MultiplicationFixture(const std::string & inputShape1,
18                                    const std::string & inputShape2,
19                                    const std::string & outputShape,
20                                    const std::string & dataType,
21                                    const std::string & activation="NONE")
22     {
23         armnn::IgnoreUnused(activation);
24         m_JsonString = R"(
25         {
26                 inputIds: [0, 1],
27                 outputIds: [3],
28                 layers: [
29                 {
30                     layer_type: "InputLayer",
31                     layer: {
32                           base: {
33                                 layerBindingId: 0,
34                                 base: {
35                                     index: 0,
36                                     layerName: "InputLayer1",
37                                     layerType: "Input",
38                                     inputSlots: [{
39                                         index: 0,
40                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
41                                     }],
42                                     outputSlots: [ {
43                                         index: 0,
44                                         tensorInfo: {
45                                             dimensions: )" + inputShape1 + R"(,
46                                             dataType: )" + dataType + R"(
47                                         },
48                                     }],
49                                  },}},
50                 },
51                 {
52                 layer_type: "InputLayer",
53                 layer: {
54                        base: {
55                             layerBindingId: 1,
56                             base: {
57                                   index:1,
58                                   layerName: "InputLayer2",
59                                   layerType: "Input",
60                                   inputSlots: [{
61                                       index: 0,
62                                       connection: {sourceLayerIndex:0, outputSlotIndex:0 },
63                                   }],
64                                   outputSlots: [ {
65                                       index: 0,
66                                       tensorInfo: {
67                                           dimensions: )" + inputShape2 + R"(,
68                                           dataType: )" + dataType + R"(
69                                       },
70                                   }],
71                                 },}},
72                 },
73                 {
74                 layer_type: "MultiplicationLayer",
75                 layer : {
76                         base: {
77                              index:2,
78                              layerName: "MultiplicationLayer",
79                              layerType: "Multiplication",
80                              inputSlots: [
81                                             {
82                                              index: 0,
83                                              connection: {sourceLayerIndex:0, outputSlotIndex:0 },
84                                             },
85                                             {
86                                              index: 1,
87                                              connection: {sourceLayerIndex:1, outputSlotIndex:0 },
88                                             }
89                              ],
90                              outputSlots: [ {
91                                  index: 0,
92                                  tensorInfo: {
93                                      dimensions: )" + outputShape + R"(,
94                                      dataType: )" + dataType + R"(
95                                  },
96                              }],
97                             }},
98                 },
99                 {
100                 layer_type: "OutputLayer",
101                 layer: {
102                         base:{
103                               layerBindingId: 0,
104                               base: {
105                                     index: 3,
106                                     layerName: "OutputLayer",
107                                     layerType: "Output",
108                                     inputSlots: [{
109                                         index: 0,
110                                         connection: {sourceLayerIndex:2, outputSlotIndex:0 },
111                                     }],
112                                     outputSlots: [ {
113                                         index: 0,
114                                         tensorInfo: {
115                                             dimensions: )" + outputShape + R"(,
116                                             dataType: )" + dataType + R"(
117                                         },
118                                 }],
119                             }}},
120                 }]
121          }
122         )";
123         Setup();
124     }
125 };
126 
127 
128 struct SimpleMultiplicationFixture : MultiplicationFixture
129 {
SimpleMultiplicationFixtureSimpleMultiplicationFixture130     SimpleMultiplicationFixture() : MultiplicationFixture("[ 2, 2 ]",
131                                                           "[ 2, 2 ]",
132                                                           "[ 2, 2 ]",
133                                                           "QuantisedAsymm8") {}
134 };
135 
136 struct SimpleMultiplicationFixture2 : MultiplicationFixture
137 {
SimpleMultiplicationFixture2SimpleMultiplicationFixture2138     SimpleMultiplicationFixture2() : MultiplicationFixture("[ 2, 2, 1, 1 ]",
139                                                            "[ 2, 2, 1, 1 ]",
140                                                            "[ 2, 2, 1, 1 ]",
141                                                            "Float32") {}
142 };
143 
144 TEST_CASE_FIXTURE(SimpleMultiplicationFixture, "MultiplicationQuantisedAsymm8")
145 {
146   RunTest<2, armnn::DataType::QAsymmU8>(
147       0,
148       {{"InputLayer1", { 0, 1, 2, 3 }},
149       {"InputLayer2", { 4, 5, 6, 7 }}},
150       {{"OutputLayer", { 0, 5, 12, 21 }}});
151 }
152 
153 TEST_CASE_FIXTURE(SimpleMultiplicationFixture2, "MultiplicationFloat32")
154 {
155     RunTest<4, armnn::DataType::Float32>(
156     0,
157     {{"InputLayer1", { 100, 40, 226, 9 }},
158     {"InputLayer2", {   5,   8,  1, 12 }}},
159     {{"OutputLayer", { 500, 320, 226, 108 }}});
160 }
161 
162 }
163