xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeInstanceNormalization.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_InstanceNormalization")
12*89c4ff92SAndroid Build Coastguard Worker {
/// Base fixture describing a three-layer serialized network as JSON:
///   InputLayer (index 0) -> InstanceNormalizationLayer (index 1) -> OutputLayer (index 2).
/// The constructor splices its string arguments into that JSON, stores it in
/// m_JsonString, and then calls SetupSingleInputSingleOutput("InputLayer", "OutputLayer").
///
/// The input layer's tensorInfo additionally hard-codes quantizationScale 0.5 and
/// quantizationOffset 0; the other tensorInfos specify only dimensions and dataType.
struct InstanceNormalizationFixture : public ParserFlatbuffersSerializeFixture
{
    /// @param inputShape  JSON array text used as the InputLayer's output dimensions,
    ///                    e.g. "[ 2, 2, 2, 2 ]".
    /// @param outputShape JSON array text used as the output dimensions of both the
    ///                    InstanceNormalizationLayer and the OutputLayer.
    /// @param gamma       Text spliced, inside double quotes, into the descriptor's `gamma` field.
    /// @param beta        Text spliced, inside double quotes, into the descriptor's `beta` field.
    /// @param epsilon     Text spliced, unquoted, into the descriptor's `eps` field.
    /// @param dataType    Data type name written into every tensorInfo, e.g. "Float32".
    /// @param dataLayout  Layout name written into the descriptor's `dataLayout`, e.g. "NHWC".
    ///
    /// NOTE(review): gamma/beta are emitted as quoted strings while eps is a bare
    /// number; presumably the FlatBuffers JSON parser accepts numeric strings for
    /// scalar fields — confirm before relying on this elsewhere.
    explicit InstanceNormalizationFixture(const std::string &inputShape,
                                          const std::string &outputShape,
                                          const std::string &gamma,
                                          const std::string &beta,
                                          const std::string &epsilon,
                                          const std::string &dataType,
                                          const std::string &dataLayout)
    {
        // Each argument is inserted by closing the raw string literal, concatenating
        // the std::string parameter, and reopening a new raw string literal.
        m_JsonString = R"(
    {
        inputIds: [0],
        outputIds: [2],
        layers: [
           {
            layer_type: "InputLayer",
            layer: {
                base: {
                    layerBindingId: 0,
                    base: {
                        index: 0,
                        layerName: "InputLayer",
                        layerType: "Input",
                        inputSlots: [{
                            index: 0,
                            connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                            }],
                        outputSlots: [{
                            index: 0,
                            tensorInfo: {
                                dimensions: )" + inputShape + R"(,
                                dataType: ")" + dataType + R"(",
                                quantizationScale: 0.5,
                                quantizationOffset: 0
                                },
                            }]
                        },
                    }
                },
            },
        {
        layer_type: "InstanceNormalizationLayer",
        layer : {
            base: {
                index:1,
                layerName: "InstanceNormalizationLayer",
                layerType: "InstanceNormalization",
                inputSlots: [{
                        index: 0,
                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                   }],
                outputSlots: [{
                    index: 0,
                    tensorInfo: {
                        dimensions: )" + outputShape + R"(,
                        dataType: ")" + dataType + R"("
                    },
                    }],
                },
            descriptor: {
                dataLayout: ")" + dataLayout + R"(",
                gamma: ")" + gamma + R"(",
                beta: ")" + beta + R"(",
                eps: )" + epsilon + R"(
                },
            },
        },
        {
        layer_type: "OutputLayer",
        layer: {
            base:{
                layerBindingId: 0,
                base: {
                    index: 2,
                    layerName: "OutputLayer",
                    layerType: "Output",
                    inputSlots: [{
                        index: 0,
                        connection: {sourceLayerIndex:1, outputSlotIndex:0 },
                    }],
                    outputSlots: [ {
                        index: 0,
                        tensorInfo: {
                            dimensions: )" + outputShape + R"(,
                            dataType: ")" + dataType + R"("
                        },
                    }],
                }
            }},
        }]
    }
)";
        // These names match the layerName values of the first and last layers in the JSON above.
        SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
    }
};
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker struct InstanceNormalizationFloat32Fixture : InstanceNormalizationFixture
111*89c4ff92SAndroid Build Coastguard Worker {
InstanceNormalizationFloat32FixtureInstanceNormalizationFloat32Fixture112*89c4ff92SAndroid Build Coastguard Worker     InstanceNormalizationFloat32Fixture():InstanceNormalizationFixture("[ 2, 2, 2, 2 ]",
113*89c4ff92SAndroid Build Coastguard Worker                                                                        "[ 2, 2, 2, 2 ]",
114*89c4ff92SAndroid Build Coastguard Worker                                                                        "1.0",
115*89c4ff92SAndroid Build Coastguard Worker                                                                        "0.0",
116*89c4ff92SAndroid Build Coastguard Worker                                                                        "0.0001",
117*89c4ff92SAndroid Build Coastguard Worker                                                                        "Float32",
118*89c4ff92SAndroid Build Coastguard Worker                                                                        "NHWC") {}
119*89c4ff92SAndroid Build Coastguard Worker };
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(InstanceNormalizationFloat32Fixture, "InstanceNormalizationFloat32")
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(
124*89c4ff92SAndroid Build Coastguard Worker         0,
125*89c4ff92SAndroid Build Coastguard Worker          {
126*89c4ff92SAndroid Build Coastguard Worker              0.f,  1.f,
127*89c4ff92SAndroid Build Coastguard Worker              0.f,  2.f,
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker              0.f,  2.f,
130*89c4ff92SAndroid Build Coastguard Worker              0.f,  4.f,
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker              1.f, -1.f,
133*89c4ff92SAndroid Build Coastguard Worker             -1.f,  2.f,
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker             -1.f, -2.f,
136*89c4ff92SAndroid Build Coastguard Worker              1.f,  4.f
137*89c4ff92SAndroid Build Coastguard Worker         },
138*89c4ff92SAndroid Build Coastguard Worker         {
139*89c4ff92SAndroid Build Coastguard Worker              0.0000000f, -1.1470304f,
140*89c4ff92SAndroid Build Coastguard Worker              0.0000000f, -0.2294061f,
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker              0.0000000f, -0.2294061f,
143*89c4ff92SAndroid Build Coastguard Worker              0.0000000f,  1.6058424f,
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker              0.9999501f, -0.7337929f,
146*89c4ff92SAndroid Build Coastguard Worker             -0.9999501f,  0.5241377f,
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker             -0.9999501f, -1.1531031f,
149*89c4ff92SAndroid Build Coastguard Worker              0.9999501f,  1.3627582f
150*89c4ff92SAndroid Build Coastguard Worker         });
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker }
154