xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Activations.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Activations")
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker struct ActivationFixture : ParserFlatbuffersFixture
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker 
ActivationFixtureActivationFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit ActivationFixture(std::string activationFunction, std::string dataType)
16*89c4ff92SAndroid Build Coastguard Worker     {
17*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
18*89c4ff92SAndroid Build Coastguard Worker             {
19*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
20*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": )" + activationFunction + R"( } ],
21*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
22*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
23*89c4ff92SAndroid Build Coastguard Worker                         {
24*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 7 ],
25*89c4ff92SAndroid Build Coastguard Worker                             "type": )" + dataType + R"(,
26*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
27*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor",
28*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
29*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
30*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
31*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
32*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
33*89c4ff92SAndroid Build Coastguard Worker                             }
34*89c4ff92SAndroid Build Coastguard Worker                         },
35*89c4ff92SAndroid Build Coastguard Worker                         {
36*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 7 ],
37*89c4ff92SAndroid Build Coastguard Worker                             "type": )" + dataType + R"(,
38*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
39*89c4ff92SAndroid Build Coastguard Worker                             "name": "outputTensor",
40*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
41*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
42*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
43*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
44*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
45*89c4ff92SAndroid Build Coastguard Worker                             }
46*89c4ff92SAndroid Build Coastguard Worker                         }
47*89c4ff92SAndroid Build Coastguard Worker                     ],
48*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
49*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 1 ],
50*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
51*89c4ff92SAndroid Build Coastguard Worker                         {
52*89c4ff92SAndroid Build Coastguard Worker                           "opcode_index": 0,
53*89c4ff92SAndroid Build Coastguard Worker                           "inputs": [ 0 ],
54*89c4ff92SAndroid Build Coastguard Worker                           "outputs": [ 1 ],
55*89c4ff92SAndroid Build Coastguard Worker                           "custom_options_format": "FLEXBUFFERS"
56*89c4ff92SAndroid Build Coastguard Worker                         }
57*89c4ff92SAndroid Build Coastguard Worker                     ],
58*89c4ff92SAndroid Build Coastguard Worker                 } ],
59*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [ {}, {} ]
60*89c4ff92SAndroid Build Coastguard Worker             }
61*89c4ff92SAndroid Build Coastguard Worker         )";
62*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
63*89c4ff92SAndroid Build Coastguard Worker     }
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker };
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker struct ReLuFixture : ActivationFixture
68*89c4ff92SAndroid Build Coastguard Worker {
ReLuFixtureReLuFixture69*89c4ff92SAndroid Build Coastguard Worker     ReLuFixture() : ActivationFixture("RELU", "FLOAT32") {}
70*89c4ff92SAndroid Build Coastguard Worker };
71*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReLuFixture, "ParseReLu")
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(0, { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f },
74*89c4ff92SAndroid Build Coastguard Worker                                          { 0.0f, 0.0f, 1.25f, 0.0f, 0.0f, 0.5f, 0.0f });
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker struct ReLu6Fixture : ActivationFixture
78*89c4ff92SAndroid Build Coastguard Worker {
ReLu6FixtureReLu6Fixture79*89c4ff92SAndroid Build Coastguard Worker     ReLu6Fixture() : ActivationFixture("RELU6", "FLOAT32") {}
80*89c4ff92SAndroid Build Coastguard Worker };
81*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReLu6Fixture, "ParseReLu6")
82*89c4ff92SAndroid Build Coastguard Worker {
83*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(0, { -1.0f, -0.5f, 7.25f, -3.0f, 0.0f, 0.5f, -0.75f },
84*89c4ff92SAndroid Build Coastguard Worker                                          { 0.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.5f, 0.0f });
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker struct SigmoidFixture : ActivationFixture
88*89c4ff92SAndroid Build Coastguard Worker {
SigmoidFixtureSigmoidFixture89*89c4ff92SAndroid Build Coastguard Worker     SigmoidFixture() : ActivationFixture("LOGISTIC", "FLOAT32") {}
90*89c4ff92SAndroid Build Coastguard Worker };
91*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SigmoidFixture, "ParseLogistic")
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(0, { -1.0f,     -0.5f,      4.0f,       -4.0f,  0.0f,      0.5f,     -0.75f },
94*89c4ff92SAndroid Build Coastguard Worker                                          {0.268941f, 0.377541f, 0.982013f,  0.0179862f,  0.5f, 0.622459f,  0.320821f });
95*89c4ff92SAndroid Build Coastguard Worker }
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker struct TanHFixture : ActivationFixture
98*89c4ff92SAndroid Build Coastguard Worker {
TanHFixtureTanHFixture99*89c4ff92SAndroid Build Coastguard Worker     TanHFixture() : ActivationFixture("TANH", "FLOAT32") {}
100*89c4ff92SAndroid Build Coastguard Worker };
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TanHFixture, "ParseTanH")
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(0,
105*89c4ff92SAndroid Build Coastguard Worker         { -0.1f,       -0.2f,         -0.3f,       -0.4f,    0.1f,         0.2f,              0.3f },
106*89c4ff92SAndroid Build Coastguard Worker         { -0.09966799f, -0.19737528f, -0.29131261f, -0.379949f, 0.09966799f, 0.19737528f, 0.29131261f });
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker struct EluFixture : ActivationFixture
110*89c4ff92SAndroid Build Coastguard Worker {
EluFixtureEluFixture111*89c4ff92SAndroid Build Coastguard Worker     EluFixture() : ActivationFixture("ELU", "FLOAT32") {}
112*89c4ff92SAndroid Build Coastguard Worker };
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(EluFixture, "ParseElu")
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(0,
117*89c4ff92SAndroid Build Coastguard Worker                                          { -2.0f,           -1.0f,           -0.0f, 0.0f, 1.0f, 2.0f, 3.0f },
118*89c4ff92SAndroid Build Coastguard Worker                                          { -0.86466471676f, -0.63212055882f, -0.0f, 0.0f, 1.0f, 2.0f, 3.0f });
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker struct HardSwishFixture : ActivationFixture
122*89c4ff92SAndroid Build Coastguard Worker {
HardSwishFixtureHardSwishFixture123*89c4ff92SAndroid Build Coastguard Worker     HardSwishFixture() : ActivationFixture("HARD_SWISH", "FLOAT32") {}
124*89c4ff92SAndroid Build Coastguard Worker };
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(HardSwishFixture, "ParseHardSwish")
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(0,
129*89c4ff92SAndroid Build Coastguard Worker                                          { -4.0f, -3.0f,        -2.9f,  1.2f,        2.2f, 3.0f, 4.0f },
130*89c4ff92SAndroid Build Coastguard Worker                                          { -0.0f, -0.0f, -0.04833334f, 0.84f, 1.90666667f, 3.0f, 4.0f });
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker }
134