xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeLogSoftmax.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_LogSoftmax")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct LogSoftmaxFixture : public ParserFlatbuffersSerializeFixture
12*89c4ff92SAndroid Build Coastguard Worker {
LogSoftmaxFixtureLogSoftmaxFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit LogSoftmaxFixture(const std::string &shape,
14*89c4ff92SAndroid Build Coastguard Worker                                const std::string &beta,
15*89c4ff92SAndroid Build Coastguard Worker                                const std::string &axis,
16*89c4ff92SAndroid Build Coastguard Worker                                const std::string &dataType)
17*89c4ff92SAndroid Build Coastguard Worker     {
18*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
19*89c4ff92SAndroid Build Coastguard Worker         {
20*89c4ff92SAndroid Build Coastguard Worker             inputIds: [0],
21*89c4ff92SAndroid Build Coastguard Worker             outputIds: [2],
22*89c4ff92SAndroid Build Coastguard Worker             layers: [
23*89c4ff92SAndroid Build Coastguard Worker                 {
24*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "InputLayer",
25*89c4ff92SAndroid Build Coastguard Worker                 layer: {
26*89c4ff92SAndroid Build Coastguard Worker                     base: {
27*89c4ff92SAndroid Build Coastguard Worker                         layerBindingId: 0,
28*89c4ff92SAndroid Build Coastguard Worker                         base: {
29*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
30*89c4ff92SAndroid Build Coastguard Worker                             layerName: "InputLayer",
31*89c4ff92SAndroid Build Coastguard Worker                             layerType: "Input",
32*89c4ff92SAndroid Build Coastguard Worker                             inputSlots: [{
33*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
34*89c4ff92SAndroid Build Coastguard Worker                                 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
35*89c4ff92SAndroid Build Coastguard Worker                                 }],
36*89c4ff92SAndroid Build Coastguard Worker                             outputSlots: [{
37*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
38*89c4ff92SAndroid Build Coastguard Worker                                 tensorInfo: {
39*89c4ff92SAndroid Build Coastguard Worker                                     dimensions: )" + shape + R"(,
40*89c4ff92SAndroid Build Coastguard Worker                                     dataType: ")" + dataType + R"(",
41*89c4ff92SAndroid Build Coastguard Worker                                     quantizationScale: 0.5,
42*89c4ff92SAndroid Build Coastguard Worker                                     quantizationOffset: 0
43*89c4ff92SAndroid Build Coastguard Worker                                     },
44*89c4ff92SAndroid Build Coastguard Worker                                 }]
45*89c4ff92SAndroid Build Coastguard Worker                             },
46*89c4ff92SAndroid Build Coastguard Worker                         }
47*89c4ff92SAndroid Build Coastguard Worker                     },
48*89c4ff92SAndroid Build Coastguard Worker                 },
49*89c4ff92SAndroid Build Coastguard Worker             {
50*89c4ff92SAndroid Build Coastguard Worker             layer_type: "LogSoftmaxLayer",
51*89c4ff92SAndroid Build Coastguard Worker             layer : {
52*89c4ff92SAndroid Build Coastguard Worker                 base: {
53*89c4ff92SAndroid Build Coastguard Worker                     index:1,
54*89c4ff92SAndroid Build Coastguard Worker                     layerName: "LogSoftmaxLayer",
55*89c4ff92SAndroid Build Coastguard Worker                     layerType: "LogSoftmax",
56*89c4ff92SAndroid Build Coastguard Worker                     inputSlots: [{
57*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
58*89c4ff92SAndroid Build Coastguard Worker                             connection: { sourceLayerIndex:0, outputSlotIndex:0 },
59*89c4ff92SAndroid Build Coastguard Worker                         }],
60*89c4ff92SAndroid Build Coastguard Worker                     outputSlots: [{
61*89c4ff92SAndroid Build Coastguard Worker                         index: 0,
62*89c4ff92SAndroid Build Coastguard Worker                         tensorInfo: {
63*89c4ff92SAndroid Build Coastguard Worker                             dimensions: )" + shape + R"(,
64*89c4ff92SAndroid Build Coastguard Worker                             dataType: ")" + dataType + R"("
65*89c4ff92SAndroid Build Coastguard Worker                         },
66*89c4ff92SAndroid Build Coastguard Worker                         }],
67*89c4ff92SAndroid Build Coastguard Worker                     },
68*89c4ff92SAndroid Build Coastguard Worker                 descriptor: {
69*89c4ff92SAndroid Build Coastguard Worker                     beta: ")" + beta + R"(",
70*89c4ff92SAndroid Build Coastguard Worker                     axis: )" + axis + R"(
71*89c4ff92SAndroid Build Coastguard Worker                     },
72*89c4ff92SAndroid Build Coastguard Worker                 },
73*89c4ff92SAndroid Build Coastguard Worker             },
74*89c4ff92SAndroid Build Coastguard Worker             {
75*89c4ff92SAndroid Build Coastguard Worker             layer_type: "OutputLayer",
76*89c4ff92SAndroid Build Coastguard Worker             layer: {
77*89c4ff92SAndroid Build Coastguard Worker                 base:{
78*89c4ff92SAndroid Build Coastguard Worker                     layerBindingId: 0,
79*89c4ff92SAndroid Build Coastguard Worker                     base: {
80*89c4ff92SAndroid Build Coastguard Worker                         index: 2,
81*89c4ff92SAndroid Build Coastguard Worker                         layerName: "OutputLayer",
82*89c4ff92SAndroid Build Coastguard Worker                         layerType: "Output",
83*89c4ff92SAndroid Build Coastguard Worker                         inputSlots: [{
84*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
85*89c4ff92SAndroid Build Coastguard Worker                             connection: { sourceLayerIndex:1, outputSlotIndex:0 },
86*89c4ff92SAndroid Build Coastguard Worker                         }],
87*89c4ff92SAndroid Build Coastguard Worker                         outputSlots: [ {
88*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
89*89c4ff92SAndroid Build Coastguard Worker                             tensorInfo: {
90*89c4ff92SAndroid Build Coastguard Worker                                 dimensions: )" + shape + R"(,
91*89c4ff92SAndroid Build Coastguard Worker                                 dataType: ")" + dataType + R"("
92*89c4ff92SAndroid Build Coastguard Worker                             },
93*89c4ff92SAndroid Build Coastguard Worker                         }],
94*89c4ff92SAndroid Build Coastguard Worker                     }
95*89c4ff92SAndroid Build Coastguard Worker                 }},
96*89c4ff92SAndroid Build Coastguard Worker             }]
97*89c4ff92SAndroid Build Coastguard Worker         })";
98*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker };
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker struct LogSoftmaxFloat32Fixture : LogSoftmaxFixture
103*89c4ff92SAndroid Build Coastguard Worker {
LogSoftmaxFloat32FixtureLogSoftmaxFloat32Fixture104*89c4ff92SAndroid Build Coastguard Worker     LogSoftmaxFloat32Fixture() :
105*89c4ff92SAndroid Build Coastguard Worker         LogSoftmaxFixture("[ 1, 1, 2, 4 ]", // inputShape
106*89c4ff92SAndroid Build Coastguard Worker                           "1.0",            // beta
107*89c4ff92SAndroid Build Coastguard Worker                           "3",              // axis
108*89c4ff92SAndroid Build Coastguard Worker                           "Float32")        // dataType
109*89c4ff92SAndroid Build Coastguard Worker     {}
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(LogSoftmaxFloat32Fixture, "LogSoftmaxFloat32")
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(
115*89c4ff92SAndroid Build Coastguard Worker         0,
116*89c4ff92SAndroid Build Coastguard Worker         {
117*89c4ff92SAndroid Build Coastguard Worker             0.f, -6.f,  2.f, 4.f,
118*89c4ff92SAndroid Build Coastguard Worker             3.f, -2.f, 10.f, 1.f
119*89c4ff92SAndroid Build Coastguard Worker         },
120*89c4ff92SAndroid Build Coastguard Worker         {
121*89c4ff92SAndroid Build Coastguard Worker             -4.14297f, -10.14297f, -2.14297f, -0.14297f,
122*89c4ff92SAndroid Build Coastguard Worker             -7.00104f, -12.00104f, -0.00105f, -9.00104f
123*89c4ff92SAndroid Build Coastguard Worker         });
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker }
127