1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include "OnnxParserTestUtils.hpp" 9 10 TEST_SUITE("OnnxParser_Gather") 11 { 12 13 struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14 { GatherMainFixtureGatherMainFixture15 GatherMainFixture(const std::vector<int>& indicesShape, 16 const std::vector<int>& indices, 17 const std::vector<int>& inputShape, 18 const std::vector<int>& outputShape) 19 { 20 m_Prototext = R"( 21 ir_version: 8 22 producer_name: "onnx-example" 23 graph { 24 node { 25 output: "indices" 26 op_type: "Constant" 27 attribute { 28 name: "value" 29 t { 30 data_type: 7 31 )" + ConstructIndicesString(indicesShape, indices) + R"( 32 name: "value" 33 } 34 type: TENSOR 35 } 36 } 37 node { 38 input: "input" 39 input: "indices" 40 output: "output" 41 op_type: "Gather" 42 attribute { 43 name: "axis" 44 i: 0 45 type: INT 46 } 47 } 48 name: "gather-model" 49 input { 50 name: "input" 51 type { 52 tensor_type { 53 elem_type: 1 54 shape { 55 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 56 } 57 } 58 } 59 } 60 output { 61 name: "output" 62 type { 63 tensor_type { 64 elem_type: 1 65 shape { 66 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 67 } 68 } 69 } 70 } 71 })"; 72 } ConstructIndicesStringGatherMainFixture73 std::string ConstructIndicesString(const std::vector<int>& indicesShape, const std::vector<int>& indices) 74 { 75 std::string shapeStr; 76 for (int i : indicesShape) 77 { 78 shapeStr = fmt::format(" {} dims: {}", shapeStr, i); 79 } 80 for (int i : indices) 81 { 82 shapeStr = fmt::format(" {} int64_data: {}", shapeStr, i); 83 } 84 return shapeStr; 85 } 86 }; 87 88 struct GatherScalarFixture : GatherMainFixture 89 { GatherScalarFixtureGatherScalarFixture90 GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { }) 91 { 92 Setup(); 93 } 94 }; 95 96 struct Gather1dFixture : GatherMainFixture 97 { Gather1dFixtureGather1dFixture98 Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 }) 99 { 100 Setup(); 101 } 102 }; 103 104 struct Gather2dFixture : GatherMainFixture 105 { Gather2dFixtureGather2dFixture106 Gather2dFixture() : GatherMainFixture({ 3 }, { 1, 3, 4 }, { 5, 2 }, { 3, 2 }) 107 { 108 Setup(); 109 } 110 }; 111 112 struct Gather3dMultiIndicesFixture : GatherMainFixture 113 { Gather3dMultiIndicesFixtureGather3dMultiIndicesFixture114 Gather3dMultiIndicesFixture() : GatherMainFixture({ 2, 3 }, { 1, 2, 1, 2, 1, 0 }, { 3, 2, 3 }, { 2, 3, 2, 3 }) 115 { 116 Setup(); 117 } 118 }; 119 120 struct Gather4dFixture : GatherMainFixture 121 { Gather4dFixtureGather4dFixture122 Gather4dFixture() : GatherMainFixture({ 3 }, { 0, 1, 3 }, { 5, 4, 3, 2 }, { 3, 4, 3, 2 }) 123 { 124 Setup(); 125 } 126 }; 127 128 TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest") 129 { 130 RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, 131 {{"output", { 1.0f }}}); 132 } 133 134 TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest") 135 { 136 RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, 137 {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}}); 138 } 139 140 TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest") 141 { 142 RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, 143 {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); 144 } 145 146 TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest") 147 { 148 RunTest<3, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 149 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 150 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}}, 151 {{"output", { 7.0f, 8.0f, 9.0f, 152 10.0f, 11.0f, 12.0f, 153 13.0f, 14.0f, 15.0f, 154 16.0f, 17.0f, 18.0f, 155 7.0f, 8.0f, 9.0f, 156 10.0f, 11.0f, 12.0f, 157 13.0f, 14.0f, 15.0f, 158 16.0f, 17.0f, 18.0f, 159 7.0f, 8.0f, 9.0f, 160 10.0f, 11.0f, 12.0f, 161 1.0f, 2.0f, 3.0f, 162 4.0f, 5.0f, 6.0f }}}); 163 } 164 165 TEST_CASE_FIXTURE(Gather4dFixture, "Gather4dTest") 166 { 167 RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 168 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 169 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 170 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 171 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 172 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 173 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 174 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 175 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 176 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 177 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 178 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 179 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 180 66.0f, 67.0f, 68.0f, 69.0f, 70.0f, 181 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 182 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, 183 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, 184 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, 185 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 186 96.0f, 97.0f, 98.0f, 99.0f, 100.0f, 187 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 188 106.0f, 107.0f, 108.0f, 109.0f, 110.0f, 189 111.0f, 112.0f, 113.0f, 114.0f, 115.0f, 190 116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}}, 191 {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 192 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 193 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 194 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 195 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 196 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 197 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 198 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 199 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 200 79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f, 201 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, 202 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}}); 203 } 204 205 struct GatherRawDataFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 206 { GatherRawDataFixtureGatherRawDataFixture207 GatherRawDataFixture() 208 { 209 m_Prototext = R"( 210 ir_version: 8 211 producer_name: "onnx-example" 212 graph { 213 node { 214 output: "indices" 215 op_type: "Constant" 216 attribute { 217 name: "value" 218 t { 219 dims: 3 220 data_type: 7 221 raw_data: 222 "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000" 223 name: "value" 224 } 225 type: TENSOR 226 } 227 } 228 node { 229 input: "input" 230 input: "indices" 231 output: "output" 232 op_type: "Gather" 233 attribute { 234 name: "axis" 235 i: 0 236 type: INT 237 } 238 } 239 name: "gather-model" 240 input { 241 name: "input" 242 type { 243 tensor_type { 244 elem_type: 1 245 shape { 246 dim { 247 dim_value: 5 248 } 249 dim { 250 dim_value: 4 251 } 252 dim { 253 dim_value: 3 254 } 255 dim { 256 dim_value: 2 257 } 258 } 259 } 260 } 261 } 262 output { 263 name: "output" 264 type { 265 tensor_type { 266 elem_type: 1 267 shape { 268 dim { 269 dim_value: 3 270 } 271 dim { 272 dim_value: 4 273 } 274 dim { 275 dim_value: 3 276 } 277 dim { 278 dim_value: 2 279 } 280 } 281 } 282 } 283 } 284 })"; 285 Setup(); 286 } 287 }; 288 289 TEST_CASE_FIXTURE(GatherRawDataFixture, "GatherRawDataTest") 290 { 291 RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 292 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 293 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 294 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 295 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 296 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 297 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 298 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 299 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 300 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 301 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 302 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 303 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 304 66.0f, 67.0f, 68.0f, 69.0f, 70.0f, 305 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 306 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, 307 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, 308 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, 309 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 310 96.0f, 97.0f, 98.0f, 99.0f, 100.0f, 311 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 312 106.0f, 107.0f, 108.0f, 109.0f, 110.0f, 313 111.0f, 112.0f, 113.0f, 114.0f, 115.0f, 314 116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}}, 315 {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 316 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 317 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 318 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 319 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 320 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 321 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 322 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 323 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 324 79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f, 325 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, 326 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}}); 327 } 328 329 } 330