xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/L2Normalization.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include "ParserFlatbuffersFixture.hpp"

#include <algorithm>
#include <cmath>
#include <numeric>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_L2Normalization")
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker struct L2NormalizationFixture : public ParserFlatbuffersFixture
13*89c4ff92SAndroid Build Coastguard Worker {
L2NormalizationFixtureL2NormalizationFixture14*89c4ff92SAndroid Build Coastguard Worker     explicit L2NormalizationFixture(const std::string & inputOutputShape)
15*89c4ff92SAndroid Build Coastguard Worker     {
16*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
17*89c4ff92SAndroid Build Coastguard Worker             {
18*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
19*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "L2_NORMALIZATION" } ],
20*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
21*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
22*89c4ff92SAndroid Build Coastguard Worker                         {
23*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + inputOutputShape + R"(,
24*89c4ff92SAndroid Build Coastguard Worker                             "type": "FLOAT32",
25*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
26*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor",
27*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
28*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
29*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
30*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
31*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
32*89c4ff92SAndroid Build Coastguard Worker                             }
33*89c4ff92SAndroid Build Coastguard Worker                         },
34*89c4ff92SAndroid Build Coastguard Worker                         {
35*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + inputOutputShape + R"(,
36*89c4ff92SAndroid Build Coastguard Worker                             "type": "FLOAT32",
37*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
38*89c4ff92SAndroid Build Coastguard Worker                             "name": "outputTensor",
39*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
40*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
41*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
42*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
43*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
44*89c4ff92SAndroid Build Coastguard Worker                             }
45*89c4ff92SAndroid Build Coastguard Worker                         }
46*89c4ff92SAndroid Build Coastguard Worker                     ],
47*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
48*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 1 ],
49*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
50*89c4ff92SAndroid Build Coastguard Worker                         {
51*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
52*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0 ],
53*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 1 ],
54*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
55*89c4ff92SAndroid Build Coastguard Worker                         }
56*89c4ff92SAndroid Build Coastguard Worker                     ],
57*89c4ff92SAndroid Build Coastguard Worker                 } ],
58*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [
59*89c4ff92SAndroid Build Coastguard Worker                     { }
60*89c4ff92SAndroid Build Coastguard Worker                 ]
61*89c4ff92SAndroid Build Coastguard Worker             }
62*89c4ff92SAndroid Build Coastguard Worker         )";
63*89c4ff92SAndroid Build Coastguard Worker         Setup();
64*89c4ff92SAndroid Build Coastguard Worker     }
65*89c4ff92SAndroid Build Coastguard Worker };
66*89c4ff92SAndroid Build Coastguard Worker 
/// Reference L2 norm used to build expected outputs: sqrt(max(sum(x_i^2), eps)).
/// The sum of squares is clamped from below by eps before the square root, so an all-zero
/// (or empty) input yields sqrt(eps) instead of 0 and the expected division stays finite.
/// @param elements Values whose L2 norm is computed.
/// @param eps      Lower bound applied to the sum of squares (defaults to 1e-12f).
/// @return sqrt(max(sum of squares, eps)).
static float CalcL2Norm(std::initializer_list<float> elements, float eps = 1e-12f)
{
    const float sumOfSquares = std::accumulate(elements.begin(), elements.end(), 0.0f,
        [](float acc, float element) { return acc + element * element; });
    // std::max(a, b) is (a < b) ? b : a, i.e. the same clamp as "sum < eps ? eps : sum".
    return std::sqrt(std::max(sumOfSquares, eps));
}
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker struct L2NormalizationFixture4D : L2NormalizationFixture
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker     // TfLite uses NHWC shape
L2NormalizationFixture4DL2NormalizationFixture4D79*89c4ff92SAndroid Build Coastguard Worker     L2NormalizationFixture4D() : L2NormalizationFixture("[ 1, 1, 4, 3 ]") {}
80*89c4ff92SAndroid Build Coastguard Worker };
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(L2NormalizationFixture4D, "ParseL2Normalization4D")
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker   RunTest<4, armnn::DataType::Float32>(
85*89c4ff92SAndroid Build Coastguard Worker       0,
86*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 1.0f,  2.0f,  3.0f,
87*89c4ff92SAndroid Build Coastguard Worker                          4.0f,  5.0f,  6.0f,
88*89c4ff92SAndroid Build Coastguard Worker                          7.0f,  8.0f,  9.0f,
89*89c4ff92SAndroid Build Coastguard Worker                          10.0f, 11.0f, 12.0f }}},
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 1.0f  / CalcL2Norm({ 1.0f,  2.0f,  3.0f }),
92*89c4ff92SAndroid Build Coastguard Worker                           2.0f  / CalcL2Norm({ 1.0f,  2.0f,  3.0f }),
93*89c4ff92SAndroid Build Coastguard Worker                           3.0f  / CalcL2Norm({ 1.0f,  2.0f,  3.0f }),
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker                           4.0f  / CalcL2Norm({ 4.0f,  5.0f,  6.0f }),
96*89c4ff92SAndroid Build Coastguard Worker                           5.0f  / CalcL2Norm({ 4.0f,  5.0f,  6.0f }),
97*89c4ff92SAndroid Build Coastguard Worker                           6.0f  / CalcL2Norm({ 4.0f,  5.0f,  6.0f }),
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker                           7.0f  / CalcL2Norm({ 7.0f,  8.0f,  9.0f }),
100*89c4ff92SAndroid Build Coastguard Worker                           8.0f  / CalcL2Norm({ 7.0f,  8.0f,  9.0f }),
101*89c4ff92SAndroid Build Coastguard Worker                           9.0f  / CalcL2Norm({ 7.0f,  8.0f,  9.0f }),
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker                           10.0f / CalcL2Norm({ 10.0f, 11.0f, 12.0f }),
104*89c4ff92SAndroid Build Coastguard Worker                           11.0f / CalcL2Norm({ 10.0f, 11.0f, 12.0f }),
105*89c4ff92SAndroid Build Coastguard Worker                           12.0f / CalcL2Norm({ 10.0f, 11.0f, 12.0f }) }}});
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker struct L2NormalizationSimpleFixture4D : L2NormalizationFixture
109*89c4ff92SAndroid Build Coastguard Worker {
L2NormalizationSimpleFixture4DL2NormalizationSimpleFixture4D110*89c4ff92SAndroid Build Coastguard Worker     L2NormalizationSimpleFixture4D() : L2NormalizationFixture("[ 1, 1, 1, 4 ]") {}
111*89c4ff92SAndroid Build Coastguard Worker };
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(L2NormalizationSimpleFixture4D, "ParseL2NormalizationEps4D")
114*89c4ff92SAndroid Build Coastguard Worker {
115*89c4ff92SAndroid Build Coastguard Worker       RunTest<4, armnn::DataType::Float32>(
116*89c4ff92SAndroid Build Coastguard Worker       0,
117*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }}},
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 0.00000001f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }),
120*89c4ff92SAndroid Build Coastguard Worker                           0.00000002f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }),
121*89c4ff92SAndroid Build Coastguard Worker                           0.00000003f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }),
122*89c4ff92SAndroid Build Coastguard Worker                           0.00000004f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }) }}});
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker }
126