// xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Clip.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <limits>
#include <string>

9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Clip")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct ClipMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12*89c4ff92SAndroid Build Coastguard Worker {
ClipMainFixtureClipMainFixture13*89c4ff92SAndroid Build Coastguard Worker     ClipMainFixture(std::string min, std::string max)
14*89c4ff92SAndroid Build Coastguard Worker     {
15*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
16*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
17*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
18*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
19*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
20*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
21*89c4ff92SAndroid Build Coastguard Worker                    graph {
22*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
23*89c4ff92SAndroid Build Coastguard Worker                      input {
24*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
25*89c4ff92SAndroid Build Coastguard Worker                         type {
26*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
27*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
28*89c4ff92SAndroid Build Coastguard Worker                             shape {
29*89c4ff92SAndroid Build Coastguard Worker                               dim {
30*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 5
31*89c4ff92SAndroid Build Coastguard Worker                               }
32*89c4ff92SAndroid Build Coastguard Worker                             }
33*89c4ff92SAndroid Build Coastguard Worker                           }
34*89c4ff92SAndroid Build Coastguard Worker                         }
35*89c4ff92SAndroid Build Coastguard Worker                       }
36*89c4ff92SAndroid Build Coastguard Worker                      node {
37*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
38*89c4ff92SAndroid Build Coastguard Worker                          input:")" + min + R"("
39*89c4ff92SAndroid Build Coastguard Worker                          input:")" + max + R"("
40*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
41*89c4ff92SAndroid Build Coastguard Worker                          name: "ActivationLayer"
42*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Clip"
43*89c4ff92SAndroid Build Coastguard Worker                     }
44*89c4ff92SAndroid Build Coastguard Worker                       output {
45*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
46*89c4ff92SAndroid Build Coastguard Worker                           type {
47*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
48*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
49*89c4ff92SAndroid Build Coastguard Worker                                shape {
50*89c4ff92SAndroid Build Coastguard Worker                                    dim {
51*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 5
52*89c4ff92SAndroid Build Coastguard Worker                                    }
53*89c4ff92SAndroid Build Coastguard Worker                                }
54*89c4ff92SAndroid Build Coastguard Worker                             }
55*89c4ff92SAndroid Build Coastguard Worker                          }
56*89c4ff92SAndroid Build Coastguard Worker                       }
57*89c4ff92SAndroid Build Coastguard Worker                     }
58*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
59*89c4ff92SAndroid Build Coastguard Worker                       version: 7
60*89c4ff92SAndroid Build Coastguard Worker                     })";
61*89c4ff92SAndroid Build Coastguard Worker         Setup();
62*89c4ff92SAndroid Build Coastguard Worker     }
63*89c4ff92SAndroid Build Coastguard Worker };
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker struct ClipAttributeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
66*89c4ff92SAndroid Build Coastguard Worker {
ClipAttributeFixtureClipAttributeFixture67*89c4ff92SAndroid Build Coastguard Worker     ClipAttributeFixture(std::string min, std::string max)
68*89c4ff92SAndroid Build Coastguard Worker     {
69*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
70*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
71*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
72*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
73*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
74*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
75*89c4ff92SAndroid Build Coastguard Worker                    graph {
76*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
77*89c4ff92SAndroid Build Coastguard Worker                      input {
78*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
79*89c4ff92SAndroid Build Coastguard Worker                         type {
80*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
81*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
82*89c4ff92SAndroid Build Coastguard Worker                             shape {
83*89c4ff92SAndroid Build Coastguard Worker                               dim {
84*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 5
85*89c4ff92SAndroid Build Coastguard Worker                               }
86*89c4ff92SAndroid Build Coastguard Worker                             }
87*89c4ff92SAndroid Build Coastguard Worker                           }
88*89c4ff92SAndroid Build Coastguard Worker                         }
89*89c4ff92SAndroid Build Coastguard Worker                       }
90*89c4ff92SAndroid Build Coastguard Worker                      node {
91*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
92*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
93*89c4ff92SAndroid Build Coastguard Worker                          name: "ActivationLayer"
94*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Clip"
95*89c4ff92SAndroid Build Coastguard Worker                          attribute {
96*89c4ff92SAndroid Build Coastguard Worker                            name: "min"
97*89c4ff92SAndroid Build Coastguard Worker                            f:  )" + min + R"(
98*89c4ff92SAndroid Build Coastguard Worker                            type: FLOAT
99*89c4ff92SAndroid Build Coastguard Worker                          }
100*89c4ff92SAndroid Build Coastguard Worker                          attribute {
101*89c4ff92SAndroid Build Coastguard Worker                            name: "max"
102*89c4ff92SAndroid Build Coastguard Worker                            f:  )" + max + R"(
103*89c4ff92SAndroid Build Coastguard Worker                            type: FLOAT
104*89c4ff92SAndroid Build Coastguard Worker                          }
105*89c4ff92SAndroid Build Coastguard Worker                     }
106*89c4ff92SAndroid Build Coastguard Worker                       output {
107*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
108*89c4ff92SAndroid Build Coastguard Worker                           type {
109*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
110*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
111*89c4ff92SAndroid Build Coastguard Worker                                shape {
112*89c4ff92SAndroid Build Coastguard Worker                                    dim {
113*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 5
114*89c4ff92SAndroid Build Coastguard Worker                                    }
115*89c4ff92SAndroid Build Coastguard Worker                                }
116*89c4ff92SAndroid Build Coastguard Worker                             }
117*89c4ff92SAndroid Build Coastguard Worker                          }
118*89c4ff92SAndroid Build Coastguard Worker                       }
119*89c4ff92SAndroid Build Coastguard Worker                     }
120*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
121*89c4ff92SAndroid Build Coastguard Worker                       version: 7
122*89c4ff92SAndroid Build Coastguard Worker                     })";
123*89c4ff92SAndroid Build Coastguard Worker         Setup();
124*89c4ff92SAndroid Build Coastguard Worker     }
125*89c4ff92SAndroid Build Coastguard Worker };
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker struct ClipFixture : ClipMainFixture
128*89c4ff92SAndroid Build Coastguard Worker {
ClipFixtureClipFixture129*89c4ff92SAndroid Build Coastguard Worker     ClipFixture() : ClipMainFixture("2", "3.5") {}
130*89c4ff92SAndroid Build Coastguard Worker };
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClipFixture, "ValidClipTest")
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input",  { -1.5f, 1.25f, 3.5f, 8.0, 2.5}}},
135*89c4ff92SAndroid Build Coastguard Worker                {{ "Output", { 2.0f, 2.0f, 3.5f, 3.5, 2.5}}});
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker struct ClipNoMaxInputFixture : ClipMainFixture
139*89c4ff92SAndroid Build Coastguard Worker {
ClipNoMaxInputFixtureClipNoMaxInputFixture140*89c4ff92SAndroid Build Coastguard Worker     ClipNoMaxInputFixture() : ClipMainFixture("0", std::string()) {}
141*89c4ff92SAndroid Build Coastguard Worker };
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClipNoMaxInputFixture, "ValidNoMaxInputClipTest")
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input",  { -1.5f, -5.25f, -0.5f, 8.0f, std::numeric_limits<float>::max() }}},
146*89c4ff92SAndroid Build Coastguard Worker                {{ "Output", { 0.0f, 0.0f, 0.0f, 8.0f, std::numeric_limits<float>::max() }}});
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker struct ClipNoMinInputFixture : ClipMainFixture
150*89c4ff92SAndroid Build Coastguard Worker {
ClipNoMinInputFixtureClipNoMinInputFixture151*89c4ff92SAndroid Build Coastguard Worker     ClipNoMinInputFixture() : ClipMainFixture(std::string(), "6") {}
152*89c4ff92SAndroid Build Coastguard Worker };
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClipNoMinInputFixture, "ValidNoMinInputClipTest")
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input",   { std::numeric_limits<float>::lowest(), -5.25f, -0.5f, 8.0f, 200.0f }}},
157*89c4ff92SAndroid Build Coastguard Worker                {{ "Output", { std::numeric_limits<float>::lowest(), -5.25f, -0.5f, 6.0f, 6.0f }}});
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker 
160*89c4ff92SAndroid Build Coastguard Worker struct ClipNoInputFixture : ClipMainFixture
161*89c4ff92SAndroid Build Coastguard Worker {
ClipNoInputFixtureClipNoInputFixture162*89c4ff92SAndroid Build Coastguard Worker     ClipNoInputFixture() : ClipMainFixture(std::string(), std::string()) {}
163*89c4ff92SAndroid Build Coastguard Worker };
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClipNoInputFixture, "ValidNoInputClipTest")
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{"Input",   { std::numeric_limits<float>::lowest(), -1.25f, 3.5f, 8.0f,
168*89c4ff92SAndroid Build Coastguard Worker                               std::numeric_limits<float>::max()}}},
169*89c4ff92SAndroid Build Coastguard Worker                {{ "Output", { std::numeric_limits<float>::lowest(), -1.25f, 3.5f, 8.0f,
170*89c4ff92SAndroid Build Coastguard Worker                               std::numeric_limits<float>::max()}}});
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker struct ClipMinMaxAttributeFixture : ClipAttributeFixture
174*89c4ff92SAndroid Build Coastguard Worker {
ClipMinMaxAttributeFixtureClipMinMaxAttributeFixture175*89c4ff92SAndroid Build Coastguard Worker     ClipMinMaxAttributeFixture() : ClipAttributeFixture("2", "3.5") {}
176*89c4ff92SAndroid Build Coastguard Worker };
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClipMinMaxAttributeFixture, "ValidClipAttributeTest")
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker     RunTest<1>({{ "Input",  { -1.5f, 1.25f, 3.5f, 8.0, 2.5}}},
181*89c4ff92SAndroid Build Coastguard Worker                {{ "Output", { 2.0f, 2.0f, 3.5f, 3.5, 2.5}}});
182*89c4ff92SAndroid Build Coastguard Worker }
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker }
185