xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeActivation.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <string>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("DeserializeParser_Activation")
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker struct ActivationFixture : public ParserFlatbuffersSerializeFixture
16*89c4ff92SAndroid Build Coastguard Worker {
ActivationFixtureActivationFixture17*89c4ff92SAndroid Build Coastguard Worker     explicit ActivationFixture(const std::string& inputShape,
18*89c4ff92SAndroid Build Coastguard Worker                                const std::string& outputShape,
19*89c4ff92SAndroid Build Coastguard Worker                                const std::string& dataType,
20*89c4ff92SAndroid Build Coastguard Worker                                const std::string& activationType="Sigmoid",
21*89c4ff92SAndroid Build Coastguard Worker                                const std::string& a = "0.0",
22*89c4ff92SAndroid Build Coastguard Worker                                const std::string& b = "0.0")
23*89c4ff92SAndroid Build Coastguard Worker     {
24*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
25*89c4ff92SAndroid Build Coastguard Worker         {
26*89c4ff92SAndroid Build Coastguard Worker             inputIds: [0],
27*89c4ff92SAndroid Build Coastguard Worker             outputIds: [2],
28*89c4ff92SAndroid Build Coastguard Worker             layers: [{
29*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "InputLayer",
30*89c4ff92SAndroid Build Coastguard Worker                 layer: {
31*89c4ff92SAndroid Build Coastguard Worker                     base: {
32*89c4ff92SAndroid Build Coastguard Worker                         layerBindingId: 0,
33*89c4ff92SAndroid Build Coastguard Worker                         base: {
34*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
35*89c4ff92SAndroid Build Coastguard Worker                             layerName: "InputLayer",
36*89c4ff92SAndroid Build Coastguard Worker                             layerType: "Input",
37*89c4ff92SAndroid Build Coastguard Worker                             inputSlots: [{
38*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
39*89c4ff92SAndroid Build Coastguard Worker                                 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
40*89c4ff92SAndroid Build Coastguard Worker                             }],
41*89c4ff92SAndroid Build Coastguard Worker                             outputSlots: [{
42*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
43*89c4ff92SAndroid Build Coastguard Worker                                 tensorInfo: {
44*89c4ff92SAndroid Build Coastguard Worker                                     dimensions: )" + inputShape + R"(,
45*89c4ff92SAndroid Build Coastguard Worker                                     dataType: )" + dataType + R"(
46*89c4ff92SAndroid Build Coastguard Worker                                 },
47*89c4ff92SAndroid Build Coastguard Worker                             }],
48*89c4ff92SAndroid Build Coastguard Worker                         },
49*89c4ff92SAndroid Build Coastguard Worker                     }
50*89c4ff92SAndroid Build Coastguard Worker                 },
51*89c4ff92SAndroid Build Coastguard Worker             },
52*89c4ff92SAndroid Build Coastguard Worker             {
53*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "ActivationLayer",
54*89c4ff92SAndroid Build Coastguard Worker                 layer : {
55*89c4ff92SAndroid Build Coastguard Worker                     base: {
56*89c4ff92SAndroid Build Coastguard Worker                         index:1,
57*89c4ff92SAndroid Build Coastguard Worker                         layerName: "ActivationLayer",
58*89c4ff92SAndroid Build Coastguard Worker                         layerType: "Activation",
59*89c4ff92SAndroid Build Coastguard Worker                         inputSlots: [{
60*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
61*89c4ff92SAndroid Build Coastguard Worker                             connection: {sourceLayerIndex:0, outputSlotIndex:0 },
62*89c4ff92SAndroid Build Coastguard Worker                         }],
63*89c4ff92SAndroid Build Coastguard Worker                         outputSlots: [{
64*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
65*89c4ff92SAndroid Build Coastguard Worker                             tensorInfo: {
66*89c4ff92SAndroid Build Coastguard Worker                                 dimensions: )" + outputShape + R"(,
67*89c4ff92SAndroid Build Coastguard Worker                                 dataType: )" + dataType + R"(
68*89c4ff92SAndroid Build Coastguard Worker                             },
69*89c4ff92SAndroid Build Coastguard Worker                         }],
70*89c4ff92SAndroid Build Coastguard Worker                     },
71*89c4ff92SAndroid Build Coastguard Worker                     descriptor: {
72*89c4ff92SAndroid Build Coastguard Worker                         a: )" + a + R"(,
73*89c4ff92SAndroid Build Coastguard Worker                         b: )" + b + R"(,
74*89c4ff92SAndroid Build Coastguard Worker                         activationFunction: )" + activationType + R"(
75*89c4ff92SAndroid Build Coastguard Worker                     },
76*89c4ff92SAndroid Build Coastguard Worker                 },
77*89c4ff92SAndroid Build Coastguard Worker             },
78*89c4ff92SAndroid Build Coastguard Worker             {
79*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "OutputLayer",
80*89c4ff92SAndroid Build Coastguard Worker                 layer: {
81*89c4ff92SAndroid Build Coastguard Worker                     base:{
82*89c4ff92SAndroid Build Coastguard Worker                         layerBindingId: 2,
83*89c4ff92SAndroid Build Coastguard Worker                         base: {
84*89c4ff92SAndroid Build Coastguard Worker                             index: 2,
85*89c4ff92SAndroid Build Coastguard Worker                             layerName: "OutputLayer",
86*89c4ff92SAndroid Build Coastguard Worker                             layerType: "Output",
87*89c4ff92SAndroid Build Coastguard Worker                             inputSlots: [{
88*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
89*89c4ff92SAndroid Build Coastguard Worker                                 connection: {sourceLayerIndex:1, outputSlotIndex:0 },
90*89c4ff92SAndroid Build Coastguard Worker                             }],
91*89c4ff92SAndroid Build Coastguard Worker                             outputSlots: [{
92*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
93*89c4ff92SAndroid Build Coastguard Worker                                 tensorInfo: {
94*89c4ff92SAndroid Build Coastguard Worker                                     dimensions: )" + outputShape + R"(,
95*89c4ff92SAndroid Build Coastguard Worker                                     dataType: )" + dataType + R"(
96*89c4ff92SAndroid Build Coastguard Worker                                 },
97*89c4ff92SAndroid Build Coastguard Worker                             }],
98*89c4ff92SAndroid Build Coastguard Worker                         }
99*89c4ff92SAndroid Build Coastguard Worker                     }
100*89c4ff92SAndroid Build Coastguard Worker                 },
101*89c4ff92SAndroid Build Coastguard Worker             }]
102*89c4ff92SAndroid Build Coastguard Worker         }
103*89c4ff92SAndroid Build Coastguard Worker         )";
104*89c4ff92SAndroid Build Coastguard Worker         Setup();
105*89c4ff92SAndroid Build Coastguard Worker     }
106*89c4ff92SAndroid Build Coastguard Worker };
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker struct SimpleActivationFixture : ActivationFixture
109*89c4ff92SAndroid Build Coastguard Worker {
SimpleActivationFixtureSimpleActivationFixture110*89c4ff92SAndroid Build Coastguard Worker     SimpleActivationFixture() : ActivationFixture("[1, 2, 2, 1]",
111*89c4ff92SAndroid Build Coastguard Worker                                                   "[1, 2, 2, 1]",
112*89c4ff92SAndroid Build Coastguard Worker                                                   "QuantisedAsymm8",
113*89c4ff92SAndroid Build Coastguard Worker                                                   "ReLu") {}
114*89c4ff92SAndroid Build Coastguard Worker };
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker struct SimpleActivationFixture2 : ActivationFixture
117*89c4ff92SAndroid Build Coastguard Worker {
SimpleActivationFixture2SimpleActivationFixture2118*89c4ff92SAndroid Build Coastguard Worker     SimpleActivationFixture2() : ActivationFixture("[1, 2, 2, 1]",
119*89c4ff92SAndroid Build Coastguard Worker                                                    "[1, 2, 2, 1]",
120*89c4ff92SAndroid Build Coastguard Worker                                                    "Float32",
121*89c4ff92SAndroid Build Coastguard Worker                                                    "ReLu") {}
122*89c4ff92SAndroid Build Coastguard Worker };
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker struct SimpleActivationFixture3 : ActivationFixture
125*89c4ff92SAndroid Build Coastguard Worker {
SimpleActivationFixture3SimpleActivationFixture3126*89c4ff92SAndroid Build Coastguard Worker     SimpleActivationFixture3() : ActivationFixture("[1, 2, 2, 1]",
127*89c4ff92SAndroid Build Coastguard Worker                                                    "[1, 2, 2, 1]",
128*89c4ff92SAndroid Build Coastguard Worker                                                    "QuantisedAsymm8",
129*89c4ff92SAndroid Build Coastguard Worker                                                    "BoundedReLu",
130*89c4ff92SAndroid Build Coastguard Worker                                                    "5.0",
131*89c4ff92SAndroid Build Coastguard Worker                                                    "0.0") {}
132*89c4ff92SAndroid Build Coastguard Worker };
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker struct SimpleActivationFixture4 : ActivationFixture
135*89c4ff92SAndroid Build Coastguard Worker {
SimpleActivationFixture4SimpleActivationFixture4136*89c4ff92SAndroid Build Coastguard Worker     SimpleActivationFixture4() : ActivationFixture("[1, 2, 2, 1]",
137*89c4ff92SAndroid Build Coastguard Worker                                                    "[1, 2, 2, 1]",
138*89c4ff92SAndroid Build Coastguard Worker                                                    "Float32",
139*89c4ff92SAndroid Build Coastguard Worker                                                    "BoundedReLu",
140*89c4ff92SAndroid Build Coastguard Worker                                                    "5.0",
141*89c4ff92SAndroid Build Coastguard Worker                                                    "0.0") {}
142*89c4ff92SAndroid Build Coastguard Worker };
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleActivationFixture, "ActivationReluQuantisedAsymm8")
146*89c4ff92SAndroid Build Coastguard Worker {
147*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
148*89c4ff92SAndroid Build Coastguard Worker             0,
149*89c4ff92SAndroid Build Coastguard Worker             {{"InputLayer", {10, 0, 2, 0}}},
150*89c4ff92SAndroid Build Coastguard Worker             {{"OutputLayer", {10, 0, 2, 0}}});
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleActivationFixture2, "ActivationReluFloat32")
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(
156*89c4ff92SAndroid Build Coastguard Worker             0,
157*89c4ff92SAndroid Build Coastguard Worker             {{"InputLayer", {111, -85, 226, 3}}},
158*89c4ff92SAndroid Build Coastguard Worker             {{"OutputLayer", {111, 0, 226, 3}}});
159*89c4ff92SAndroid Build Coastguard Worker }
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleActivationFixture3, "ActivationBoundedReluQuantisedAsymm8")
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
165*89c4ff92SAndroid Build Coastguard Worker             0,
166*89c4ff92SAndroid Build Coastguard Worker             {{"InputLayer", {10, 0, 2, 0}}},
167*89c4ff92SAndroid Build Coastguard Worker             {{"OutputLayer", {5, 0, 2, 0}}});
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleActivationFixture4, "ActivationBoundedReluFloat32")
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(
173*89c4ff92SAndroid Build Coastguard Worker             0,
174*89c4ff92SAndroid Build Coastguard Worker             {{"InputLayer", {111, -85, 226, 3}}},
175*89c4ff92SAndroid Build Coastguard Worker             {{"OutputLayer", {5, 0, 5, 3}}});
176*89c4ff92SAndroid Build Coastguard Worker }
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker }
179