xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/L2Normalization.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#include "ParserFlatbuffersFixture.hpp"

#include <cmath>
#include <numeric>
9 
10 TEST_SUITE("TensorflowLiteParser_L2Normalization")
11 {
12 struct L2NormalizationFixture : public ParserFlatbuffersFixture
13 {
L2NormalizationFixtureL2NormalizationFixture14     explicit L2NormalizationFixture(const std::string & inputOutputShape)
15     {
16         m_JsonString = R"(
17             {
18                 "version": 3,
19                 "operator_codes": [ { "builtin_code": "L2_NORMALIZATION" } ],
20                 "subgraphs": [ {
21                     "tensors": [
22                         {
23                             "shape": )" + inputOutputShape + R"(,
24                             "type": "FLOAT32",
25                             "buffer": 0,
26                             "name": "inputTensor",
27                             "quantization": {
28                                 "min": [ 0.0 ],
29                                 "max": [ 255.0 ],
30                                 "scale": [ 1.0 ],
31                                 "zero_point": [ 0 ],
32                             }
33                         },
34                         {
35                             "shape": )" + inputOutputShape + R"(,
36                             "type": "FLOAT32",
37                             "buffer": 1,
38                             "name": "outputTensor",
39                             "quantization": {
40                                 "min": [ 0.0 ],
41                                 "max": [ 255.0 ],
42                                 "scale": [ 1.0 ],
43                                 "zero_point": [ 0 ],
44                             }
45                         }
46                     ],
47                     "inputs": [ 0 ],
48                     "outputs": [ 1 ],
49                     "operators": [
50                         {
51                             "opcode_index": 0,
52                             "inputs": [ 0 ],
53                             "outputs": [ 1 ],
54                             "custom_options_format": "FLEXBUFFERS"
55                         }
56                     ],
57                 } ],
58                 "buffers" : [
59                     { }
60                 ]
61             }
62         )";
63         Setup();
64     }
65 };
66 
/// Reference L2 norm used to build the expected outputs of the tests in this file.
///
/// Sums the squares of @p elements, clamps that sum from below at 1e-12 and
/// returns the square root of the clamped value. Because of the clamp the
/// result is never zero, so the tests can safely divide by it.
///
/// @param elements Values belonging to one normalisation group (may be empty).
/// @return sqrt(max(sum of squares, 1e-12)).
float CalcL2Norm(std::initializer_list<float> elements)
{
    const float sumOfSquares = std::accumulate(elements.begin(), elements.end(), 0.0f,
        [](float acc, float element) { return acc + element * element; });

    // The epsilon is applied to the sum of squares, i.e. before taking the root.
    const float epsilon = 1e-12f;
    const float clampedSum = sumOfSquares < epsilon ? epsilon : sumOfSquares;

    return std::sqrt(clampedSum);
}
75 
76 struct L2NormalizationFixture4D : L2NormalizationFixture
77 {
78     // TfLite uses NHWC shape
L2NormalizationFixture4DL2NormalizationFixture4D79     L2NormalizationFixture4D() : L2NormalizationFixture("[ 1, 1, 4, 3 ]") {}
80 };
81 
82 TEST_CASE_FIXTURE(L2NormalizationFixture4D, "ParseL2Normalization4D")
83 {
84   RunTest<4, armnn::DataType::Float32>(
85       0,
86       {{"inputTensor", { 1.0f,  2.0f,  3.0f,
87                          4.0f,  5.0f,  6.0f,
88                          7.0f,  8.0f,  9.0f,
89                          10.0f, 11.0f, 12.0f }}},
90 
91       {{"outputTensor", { 1.0f  / CalcL2Norm({ 1.0f,  2.0f,  3.0f }),
92                           2.0f  / CalcL2Norm({ 1.0f,  2.0f,  3.0f }),
93                           3.0f  / CalcL2Norm({ 1.0f,  2.0f,  3.0f }),
94 
95                           4.0f  / CalcL2Norm({ 4.0f,  5.0f,  6.0f }),
96                           5.0f  / CalcL2Norm({ 4.0f,  5.0f,  6.0f }),
97                           6.0f  / CalcL2Norm({ 4.0f,  5.0f,  6.0f }),
98 
99                           7.0f  / CalcL2Norm({ 7.0f,  8.0f,  9.0f }),
100                           8.0f  / CalcL2Norm({ 7.0f,  8.0f,  9.0f }),
101                           9.0f  / CalcL2Norm({ 7.0f,  8.0f,  9.0f }),
102 
103                           10.0f / CalcL2Norm({ 10.0f, 11.0f, 12.0f }),
104                           11.0f / CalcL2Norm({ 10.0f, 11.0f, 12.0f }),
105                           12.0f / CalcL2Norm({ 10.0f, 11.0f, 12.0f }) }}});
106 }
107 
108 struct L2NormalizationSimpleFixture4D : L2NormalizationFixture
109 {
L2NormalizationSimpleFixture4DL2NormalizationSimpleFixture4D110     L2NormalizationSimpleFixture4D() : L2NormalizationFixture("[ 1, 1, 1, 4 ]") {}
111 };
112 
113 TEST_CASE_FIXTURE(L2NormalizationSimpleFixture4D, "ParseL2NormalizationEps4D")
114 {
115       RunTest<4, armnn::DataType::Float32>(
116       0,
117       {{"inputTensor", { 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }}},
118 
119       {{"outputTensor", { 0.00000001f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }),
120                           0.00000002f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }),
121                           0.00000003f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }),
122                           0.00000004f / CalcL2Norm({ 0.00000001f, 0.00000002f, 0.00000003f, 0.00000004f }) }}});
123 }
124 
125 }
126