//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ParserFlatbuffersSerializeFixture.hpp"
#include <armnnDeserializer/IDeserializer.hpp>

#include <string>

11 TEST_SUITE("Deserializer_BatchNormalization")
12 {
13 struct BatchNormalizationFixture : public ParserFlatbuffersSerializeFixture
14 {
BatchNormalizationFixtureBatchNormalizationFixture15     explicit BatchNormalizationFixture(const std::string &inputShape,
16                                        const std::string &outputShape,
17                                        const std::string &meanShape,
18                                        const std::string &varianceShape,
19                                        const std::string &offsetShape,
20                                        const std::string &scaleShape,
21                                        const std::string &dataType,
22                                        const std::string &dataLayout)
23     {
24         m_JsonString = R"(
25     {
26         inputIds: [0],
27         outputIds: [2],
28         layers: [
29            {
30             layer_type: "InputLayer",
31             layer: {
32                 base: {
33                     layerBindingId: 0,
34                     base: {
35                         index: 0,
36                         layerName: "InputLayer",
37                         layerType: "Input",
38                         inputSlots: [{
39                             index: 0,
40                             connection: {sourceLayerIndex:0, outputSlotIndex:0 },
41                             }],
42                         outputSlots: [{
43                             index: 0,
44                             tensorInfo: {
45                                 dimensions: )" + inputShape + R"(,
46                                 dataType: ")" + dataType + R"(",
47                                 quantizationScale: 0.5,
48                                 quantizationOffset: 0
49                                 },
50                             }]
51                         },
52                     }
53                 },
54             },
55         {
56         layer_type: "BatchNormalizationLayer",
57         layer : {
58             base: {
59                 index:1,
60                 layerName: "BatchNormalizationLayer",
61                 layerType: "BatchNormalization",
62                 inputSlots: [{
63                         index: 0,
64                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
65                    }],
66                 outputSlots: [{
67                     index: 0,
68                     tensorInfo: {
69                         dimensions: )" + outputShape + R"(,
70                         dataType: ")" + dataType + R"("
71                     },
72                     }],
73                 },
74             descriptor: {
75                 eps: 0.0010000000475,
76                 dataLayout: ")" + dataLayout + R"("
77                 },
78             mean: {
79                 info: {
80                          dimensions: )" + meanShape + R"(,
81                          dataType: ")" + dataType + R"("
82                      },
83                 data_type: IntData,
84                 data: {
85                     data: [1084227584],
86                     }
87                 },
88             variance: {
89                 info: {
90                          dimensions: )" + varianceShape + R"(,
91                          dataType: ")" + dataType + R"("
92                      },
93                data_type: IntData,
94                 data: {
95                     data: [1073741824],
96                     }
97                 },
98             beta: {
99                 info: {
100                          dimensions: )" + offsetShape + R"(,
101                          dataType: ")" + dataType + R"("
102                      },
103                 data_type: IntData,
104                 data: {
105                     data: [0],
106                     }
107                 },
108             gamma: {
109                 info: {
110                          dimensions: )" + scaleShape + R"(,
111                          dataType: ")" + dataType + R"("
112                      },
113                 data_type: IntData,
114                 data: {
115                     data: [1065353216],
116                     }
117                 },
118             },
119         },
120         {
121         layer_type: "OutputLayer",
122         layer: {
123             base:{
124                 layerBindingId: 0,
125                 base: {
126                     index: 2,
127                     layerName: "OutputLayer",
128                     layerType: "Output",
129                     inputSlots: [{
130                         index: 0,
131                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
132                     }],
133                     outputSlots: [ {
134                         index: 0,
135                         tensorInfo: {
136                             dimensions: )" + outputShape + R"(,
137                             dataType: ")" + dataType + R"("
138                         },
139                     }],
140                 }
141             }},
142         }]
143     }
144 )";
145         Setup();
146     }
147 };
148 
149 struct BatchNormFixture : BatchNormalizationFixture
150 {
BatchNormFixtureBatchNormFixture151     BatchNormFixture():BatchNormalizationFixture("[ 1, 3, 3, 1 ]",
152                                                  "[ 1, 3, 3, 1 ]",
153                                                  "[ 1 ]",
154                                                  "[ 1 ]",
155                                                  "[ 1 ]",
156                                                  "[ 1 ]",
157                                                  "Float32",
158                                                  "NHWC"){}
159 };
160 
161 TEST_CASE_FIXTURE(BatchNormFixture, "BatchNormalizationFloat32")
162 {
163     RunTest<4, armnn::DataType::Float32>(0,
164                                          {{"InputLayer", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }}},
165                                          {{"OutputLayer",{ -2.8277204f, -2.12079024f, -1.4138602f,
166                                            -0.7069301f,  0.0f,         0.7069301f,
167                                            1.4138602f,  2.12079024f,  2.8277204f }}});
168 }
169 
170 }
171