1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include "OnnxParserTestUtils.hpp" 9 10 TEST_SUITE("OnnxParser_Unsqueeze") 11 { 12 13 struct UnsqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14 { UnsqueezeFixtureUnsqueezeFixture15 UnsqueezeFixture(const std::vector<int>& axes, 16 const std::vector<int>& inputShape, 17 const std::vector<int>& outputShape) 18 { 19 m_Prototext = R"( 20 ir_version: 8 21 producer_name: "onnx-example" 22 graph { 23 node { 24 input: "Input" 25 output: "Output" 26 op_type: "Unsqueeze" 27 )" + armnnUtils::ConstructIntsAttribute("axes", axes) + R"( 28 } 29 name: "test-model" 30 input { 31 name: "Input" 32 type { 33 tensor_type { 34 elem_type: 1 35 shape { 36 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 37 } 38 } 39 } 40 } 41 output { 42 name: "Output" 43 type { 44 tensor_type { 45 elem_type: 1 46 shape { 47 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 48 } 49 } 50 } 51 } 52 })"; 53 } 54 }; 55 56 struct UnsqueezeSingleAxesFixture : UnsqueezeFixture 57 { UnsqueezeSingleAxesFixtureUnsqueezeSingleAxesFixture58 UnsqueezeSingleAxesFixture() : UnsqueezeFixture({ 0 }, { 2, 3 }, { 1, 2, 3 }) 59 { 60 Setup(); 61 } 62 }; 63 64 struct UnsqueezeMultiAxesFixture : UnsqueezeFixture 65 { UnsqueezeMultiAxesFixtureUnsqueezeMultiAxesFixture66 UnsqueezeMultiAxesFixture() : UnsqueezeFixture({ 1, 3 }, { 3, 2, 5 }, { 3, 1, 2, 1, 5 }) 67 { 68 Setup(); 69 } 70 }; 71 72 struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture 73 { UnsqueezeUnsortedAxesFixtureUnsqueezeUnsortedAxesFixture74 UnsqueezeUnsortedAxesFixture() : UnsqueezeFixture({ 3, 0, 1 }, { 2, 5 }, { 1, 1, 2, 1, 5 }) 75 { 76 Setup(); 77 } 78 }; 79 80 struct UnsqueezeScalarFixture : UnsqueezeFixture 81 { UnsqueezeScalarFixtureUnsqueezeScalarFixture82 UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 }) 83 { 84 Setup(); 85 } 86 }; 87 88 TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest") 89 { 90 RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, 91 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); 92 } 93 94 TEST_CASE_FIXTURE(UnsqueezeMultiAxesFixture, "UnsqueezeMultiAxesTest") 95 { 96 RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 97 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 98 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 99 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 100 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 101 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}, 102 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 103 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 104 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 105 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 106 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 107 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}); 108 } 109 110 TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest") 111 { 112 RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 113 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, 114 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 115 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); 116 } 117 118 TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest") 119 { 120 RunTest<1, float>({{"Input", { 1.0f }}}, 121 {{"Output", { 1.0f }}}); 122 } 123 124 struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 125 { UnsqueezeInputAxesFixtureUnsqueezeInputAxesFixture126 UnsqueezeInputAxesFixture() 127 { 128 m_Prototext = R"( 129 ir_version: 8 130 producer_name: "onnx-example" 131 graph { 132 node { 133 input: "Input" 134 input: "Axes" 135 output: "Output" 136 op_type: "Unsqueeze" 137 } 138 initializer { 139 dims: 2 140 data_type: 7 141 int64_data: 0 142 int64_data: 3 143 name: "Axes" 144 } 145 name: "test-model" 146 input { 147 name: "Input" 148 type { 149 tensor_type { 150 elem_type: 1 151 shape { 152 dim { 153 dim_value: 3 154 } 155 dim { 156 dim_value: 2 157 } 158 dim { 159 dim_value: 5 160 } 161 } 162 } 163 } 164 } 165 output { 166 name: "Output" 167 type { 168 tensor_type { 169 elem_type: 1 170 shape { 171 dim { 172 dim_value: 1 173 } 174 dim { 175 dim_value: 3 176 } 177 dim { 178 dim_value: 2 179 } 180 dim { 181 dim_value: 1 182 } 183 dim { 184 dim_value: 5 185 } 186 } 187 } 188 } 189 } 190 })"; 191 Setup(); 192 } 193 }; 194 195 TEST_CASE_FIXTURE(UnsqueezeInputAxesFixture, "UnsqueezeInputAxesTest") 196 { 197 RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 198 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 199 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 200 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 201 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 202 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}, 203 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 204 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 205 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 206 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 207 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 208 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}); 209 } 210 211 } 212