xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeSubtraction.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Subtraction")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct SubtractionFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
SubtractionFixtureSubtractionFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit SubtractionFixture(const std::string & inputShape1,
16*89c4ff92SAndroid Build Coastguard Worker                                 const std::string & inputShape2,
17*89c4ff92SAndroid Build Coastguard Worker                                 const std::string & outputShape,
18*89c4ff92SAndroid Build Coastguard Worker                                 const std::string & dataType)
19*89c4ff92SAndroid Build Coastguard Worker     {
20*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
21*89c4ff92SAndroid Build Coastguard Worker         {
22*89c4ff92SAndroid Build Coastguard Worker             inputIds: [0, 1],
23*89c4ff92SAndroid Build Coastguard Worker             outputIds: [3],
24*89c4ff92SAndroid Build Coastguard Worker             layers: [
25*89c4ff92SAndroid Build Coastguard Worker             {
26*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "InputLayer",
27*89c4ff92SAndroid Build Coastguard Worker                 layer: {
28*89c4ff92SAndroid Build Coastguard Worker                        base: {
29*89c4ff92SAndroid Build Coastguard Worker                              layerBindingId: 0,
30*89c4ff92SAndroid Build Coastguard Worker                              base: {
31*89c4ff92SAndroid Build Coastguard Worker                                    index: 0,
32*89c4ff92SAndroid Build Coastguard Worker                                    layerName: "inputLayer1",
33*89c4ff92SAndroid Build Coastguard Worker                                    layerType: "Input",
34*89c4ff92SAndroid Build Coastguard Worker                                    inputSlots: [{
35*89c4ff92SAndroid Build Coastguard Worker                                        index: 0,
36*89c4ff92SAndroid Build Coastguard Worker                                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
37*89c4ff92SAndroid Build Coastguard Worker                                    }],
38*89c4ff92SAndroid Build Coastguard Worker                                    outputSlots: [ {
39*89c4ff92SAndroid Build Coastguard Worker                                        index: 0,
40*89c4ff92SAndroid Build Coastguard Worker                                        tensorInfo: {
41*89c4ff92SAndroid Build Coastguard Worker                                            dimensions: )" + inputShape1 + R"(,
42*89c4ff92SAndroid Build Coastguard Worker                                            dataType: )" + dataType + R"(
43*89c4ff92SAndroid Build Coastguard Worker                                        },
44*89c4ff92SAndroid Build Coastguard Worker                                    }],
45*89c4ff92SAndroid Build Coastguard Worker                               },
46*89c4ff92SAndroid Build Coastguard Worker                        }},
47*89c4ff92SAndroid Build Coastguard Worker             },
48*89c4ff92SAndroid Build Coastguard Worker             {
49*89c4ff92SAndroid Build Coastguard Worker             layer_type: "InputLayer",
50*89c4ff92SAndroid Build Coastguard Worker             layer: {
51*89c4ff92SAndroid Build Coastguard Worker                    base: {
52*89c4ff92SAndroid Build Coastguard Worker                          layerBindingId: 1,
53*89c4ff92SAndroid Build Coastguard Worker                          base: {
54*89c4ff92SAndroid Build Coastguard Worker                                index:1,
55*89c4ff92SAndroid Build Coastguard Worker                                layerName: "inputLayer2",
56*89c4ff92SAndroid Build Coastguard Worker                                layerType: "Input",
57*89c4ff92SAndroid Build Coastguard Worker                                inputSlots: [{
58*89c4ff92SAndroid Build Coastguard Worker                                    index: 0,
59*89c4ff92SAndroid Build Coastguard Worker                                    connection: {sourceLayerIndex:0, outputSlotIndex:0 },
60*89c4ff92SAndroid Build Coastguard Worker                                }],
61*89c4ff92SAndroid Build Coastguard Worker                                outputSlots: [ {
62*89c4ff92SAndroid Build Coastguard Worker                                    index: 0,
63*89c4ff92SAndroid Build Coastguard Worker                                    tensorInfo: {
64*89c4ff92SAndroid Build Coastguard Worker                                        dimensions: )" + inputShape2 + R"(,
65*89c4ff92SAndroid Build Coastguard Worker                                        dataType: )" + dataType + R"(
66*89c4ff92SAndroid Build Coastguard Worker                                    },
67*89c4ff92SAndroid Build Coastguard Worker                                }],
68*89c4ff92SAndroid Build Coastguard Worker                          },
69*89c4ff92SAndroid Build Coastguard Worker                    }},
70*89c4ff92SAndroid Build Coastguard Worker             },
71*89c4ff92SAndroid Build Coastguard Worker             {
72*89c4ff92SAndroid Build Coastguard Worker             layer_type: "SubtractionLayer",
73*89c4ff92SAndroid Build Coastguard Worker             layer : {
74*89c4ff92SAndroid Build Coastguard Worker                     base: {
75*89c4ff92SAndroid Build Coastguard Worker                           index:2,
76*89c4ff92SAndroid Build Coastguard Worker                           layerName: "subtractionLayer",
77*89c4ff92SAndroid Build Coastguard Worker                           layerType: "Subtraction",
78*89c4ff92SAndroid Build Coastguard Worker                           inputSlots: [{
79*89c4ff92SAndroid Build Coastguard Worker                               index: 0,
80*89c4ff92SAndroid Build Coastguard Worker                               connection: {sourceLayerIndex:0, outputSlotIndex:0 },
81*89c4ff92SAndroid Build Coastguard Worker                           },
82*89c4ff92SAndroid Build Coastguard Worker                           {
83*89c4ff92SAndroid Build Coastguard Worker                               index: 1,
84*89c4ff92SAndroid Build Coastguard Worker                               connection: {sourceLayerIndex:1, outputSlotIndex:0 },
85*89c4ff92SAndroid Build Coastguard Worker                           }],
86*89c4ff92SAndroid Build Coastguard Worker                           outputSlots: [ {
87*89c4ff92SAndroid Build Coastguard Worker                               index: 0,
88*89c4ff92SAndroid Build Coastguard Worker                               tensorInfo: {
89*89c4ff92SAndroid Build Coastguard Worker                                   dimensions: )" + outputShape + R"(,
90*89c4ff92SAndroid Build Coastguard Worker                                   dataType: )" + dataType + R"(
91*89c4ff92SAndroid Build Coastguard Worker                               },
92*89c4ff92SAndroid Build Coastguard Worker                           }],
93*89c4ff92SAndroid Build Coastguard Worker                     }},
94*89c4ff92SAndroid Build Coastguard Worker             },
95*89c4ff92SAndroid Build Coastguard Worker             {
96*89c4ff92SAndroid Build Coastguard Worker             layer_type: "OutputLayer",
97*89c4ff92SAndroid Build Coastguard Worker             layer: {
98*89c4ff92SAndroid Build Coastguard Worker                    base:{
99*89c4ff92SAndroid Build Coastguard Worker                          layerBindingId: 0,
100*89c4ff92SAndroid Build Coastguard Worker                          base: {
101*89c4ff92SAndroid Build Coastguard Worker                                index: 3,
102*89c4ff92SAndroid Build Coastguard Worker                                layerName: "outputLayer",
103*89c4ff92SAndroid Build Coastguard Worker                                layerType: "Output",
104*89c4ff92SAndroid Build Coastguard Worker                                inputSlots: [{
105*89c4ff92SAndroid Build Coastguard Worker                                    index: 0,
106*89c4ff92SAndroid Build Coastguard Worker                                    connection: {sourceLayerIndex:2, outputSlotIndex:0 },
107*89c4ff92SAndroid Build Coastguard Worker                                }],
108*89c4ff92SAndroid Build Coastguard Worker                                outputSlots: [ {
109*89c4ff92SAndroid Build Coastguard Worker                                    index: 0,
110*89c4ff92SAndroid Build Coastguard Worker                                    tensorInfo: {
111*89c4ff92SAndroid Build Coastguard Worker                                        dimensions: )" + outputShape + R"(,
112*89c4ff92SAndroid Build Coastguard Worker                                        dataType: )" + dataType + R"(
113*89c4ff92SAndroid Build Coastguard Worker                                    },
114*89c4ff92SAndroid Build Coastguard Worker                                }],
115*89c4ff92SAndroid Build Coastguard Worker                         }}},
116*89c4ff92SAndroid Build Coastguard Worker             }]
117*89c4ff92SAndroid Build Coastguard Worker         }
118*89c4ff92SAndroid Build Coastguard Worker         )";
119*89c4ff92SAndroid Build Coastguard Worker         Setup();
120*89c4ff92SAndroid Build Coastguard Worker     }
121*89c4ff92SAndroid Build Coastguard Worker };
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker struct SimpleSubtractionFixture : SubtractionFixture
124*89c4ff92SAndroid Build Coastguard Worker {
SimpleSubtractionFixtureSimpleSubtractionFixture125*89c4ff92SAndroid Build Coastguard Worker     SimpleSubtractionFixture() : SubtractionFixture("[ 1, 4 ]",
126*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 4 ]",
127*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 4 ]",
128*89c4ff92SAndroid Build Coastguard Worker                                                     "QuantisedAsymm8") {}
129*89c4ff92SAndroid Build Coastguard Worker };
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker struct SimpleSubtractionFixture2 : SubtractionFixture
132*89c4ff92SAndroid Build Coastguard Worker {
SimpleSubtractionFixture2SimpleSubtractionFixture2133*89c4ff92SAndroid Build Coastguard Worker     SimpleSubtractionFixture2() : SubtractionFixture("[ 1, 4 ]",
134*89c4ff92SAndroid Build Coastguard Worker                                                      "[ 1, 4 ]",
135*89c4ff92SAndroid Build Coastguard Worker                                                      "[ 1, 4 ]",
136*89c4ff92SAndroid Build Coastguard Worker                                                      "Float32") {}
137*89c4ff92SAndroid Build Coastguard Worker };
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker struct SimpleSubtractionFixtureBroadcast : SubtractionFixture
140*89c4ff92SAndroid Build Coastguard Worker {
SimpleSubtractionFixtureBroadcastSimpleSubtractionFixtureBroadcast141*89c4ff92SAndroid Build Coastguard Worker     SimpleSubtractionFixtureBroadcast() : SubtractionFixture("[ 1, 4 ]",
142*89c4ff92SAndroid Build Coastguard Worker                                                              "[ 1, 1 ]",
143*89c4ff92SAndroid Build Coastguard Worker                                                              "[ 1, 4 ]",
144*89c4ff92SAndroid Build Coastguard Worker                                                              "Float32") {}
145*89c4ff92SAndroid Build Coastguard Worker };
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleSubtractionFixture, "SubtractionQuantisedAsymm8")
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(
150*89c4ff92SAndroid Build Coastguard Worker         0,
151*89c4ff92SAndroid Build Coastguard Worker         {{"inputLayer1", { 4, 5, 6, 7 }},
152*89c4ff92SAndroid Build Coastguard Worker          {"inputLayer2", { 3, 2, 1, 0 }}},
153*89c4ff92SAndroid Build Coastguard Worker         {{"outputLayer", { 1, 3, 5, 7 }}});
154*89c4ff92SAndroid Build Coastguard Worker }
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleSubtractionFixture2, "SubtractionFloat32")
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(
159*89c4ff92SAndroid Build Coastguard Worker         0,
160*89c4ff92SAndroid Build Coastguard Worker         {{"inputLayer1", { 4, 5, 6, 7 }},
161*89c4ff92SAndroid Build Coastguard Worker          {"inputLayer2", { 3, 2, 1, 0 }}},
162*89c4ff92SAndroid Build Coastguard Worker         {{"outputLayer", { 1, 3, 5, 7 }}});
163*89c4ff92SAndroid Build Coastguard Worker }
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleSubtractionFixtureBroadcast, "SubtractionBroadcast")
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(
168*89c4ff92SAndroid Build Coastguard Worker         0,
169*89c4ff92SAndroid Build Coastguard Worker         {{"inputLayer1", { 4, 5, 6, 7 }},
170*89c4ff92SAndroid Build Coastguard Worker          {"inputLayer2", { 2 }}},
171*89c4ff92SAndroid Build Coastguard Worker         {{"outputLayer", { 2, 3, 4, 5 }}});
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker }
175