xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeLogSoftmax.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 TEST_SUITE("Deserializer_LogSoftmax")
10 {
11 struct LogSoftmaxFixture : public ParserFlatbuffersSerializeFixture
12 {
LogSoftmaxFixtureLogSoftmaxFixture13     explicit LogSoftmaxFixture(const std::string &shape,
14                                const std::string &beta,
15                                const std::string &axis,
16                                const std::string &dataType)
17     {
18         m_JsonString = R"(
19         {
20             inputIds: [0],
21             outputIds: [2],
22             layers: [
23                 {
24                 layer_type: "InputLayer",
25                 layer: {
26                     base: {
27                         layerBindingId: 0,
28                         base: {
29                             index: 0,
30                             layerName: "InputLayer",
31                             layerType: "Input",
32                             inputSlots: [{
33                                 index: 0,
34                                 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
35                                 }],
36                             outputSlots: [{
37                                 index: 0,
38                                 tensorInfo: {
39                                     dimensions: )" + shape + R"(,
40                                     dataType: ")" + dataType + R"(",
41                                     quantizationScale: 0.5,
42                                     quantizationOffset: 0
43                                     },
44                                 }]
45                             },
46                         }
47                     },
48                 },
49             {
50             layer_type: "LogSoftmaxLayer",
51             layer : {
52                 base: {
53                     index:1,
54                     layerName: "LogSoftmaxLayer",
55                     layerType: "LogSoftmax",
56                     inputSlots: [{
57                             index: 0,
58                             connection: { sourceLayerIndex:0, outputSlotIndex:0 },
59                         }],
60                     outputSlots: [{
61                         index: 0,
62                         tensorInfo: {
63                             dimensions: )" + shape + R"(,
64                             dataType: ")" + dataType + R"("
65                         },
66                         }],
67                     },
68                 descriptor: {
69                     beta: ")" + beta + R"(",
70                     axis: )" + axis + R"(
71                     },
72                 },
73             },
74             {
75             layer_type: "OutputLayer",
76             layer: {
77                 base:{
78                     layerBindingId: 0,
79                     base: {
80                         index: 2,
81                         layerName: "OutputLayer",
82                         layerType: "Output",
83                         inputSlots: [{
84                             index: 0,
85                             connection: { sourceLayerIndex:1, outputSlotIndex:0 },
86                         }],
87                         outputSlots: [ {
88                             index: 0,
89                             tensorInfo: {
90                                 dimensions: )" + shape + R"(,
91                                 dataType: ")" + dataType + R"("
92                             },
93                         }],
94                     }
95                 }},
96             }]
97         })";
98         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
99     }
100 };
101 
102 struct LogSoftmaxFloat32Fixture : LogSoftmaxFixture
103 {
LogSoftmaxFloat32FixtureLogSoftmaxFloat32Fixture104     LogSoftmaxFloat32Fixture() :
105         LogSoftmaxFixture("[ 1, 1, 2, 4 ]", // inputShape
106                           "1.0",            // beta
107                           "3",              // axis
108                           "Float32")        // dataType
109     {}
110 };
111 
112 TEST_CASE_FIXTURE(LogSoftmaxFloat32Fixture, "LogSoftmaxFloat32")
113 {
114     RunTest<4, armnn::DataType::Float32>(
115         0,
116         {
117             0.f, -6.f,  2.f, 4.f,
118             3.f, -2.f, 10.f, 1.f
119         },
120         {
121             -4.14297f, -10.14297f, -2.14297f, -0.14297f,
122             -7.00104f, -12.00104f, -0.00105f, -9.00104f
123         });
124 }
125 
126 }
127