//
// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "Gather.hpp"

#include <armnn/Exceptions.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <string>
10
11 namespace armnn
12 {
13
Gather(const TensorInfo & paramsInfo,const TensorInfo & indicesInfo,const TensorInfo & outputInfo,Decoder<float> & params,const int32_t * indices,Encoder<float> & output,const int32_t axis_int)14 void Gather(const TensorInfo& paramsInfo,
15 const TensorInfo& indicesInfo,
16 const TensorInfo& outputInfo,
17 Decoder<float>& params,
18 const int32_t* indices,
19 Encoder<float>& output,
20 const int32_t axis_int)
21 {
22 IgnoreUnused(outputInfo);
23
24 const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
25 ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
26 const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
27 : static_cast<unsigned int>(axis_int);
28
29 const TensorShape& paramsShape = paramsInfo.GetShape();
30
31 // Product of all dimensions to the left side of the axis
32 unsigned int paramsOuterProduct = 1;
33 for (unsigned int i = 0; i < axis; ++i)
34 {
35 paramsOuterProduct *= paramsShape[i];
36 }
37 // Product of all dimensions to the right side of the axis
38 unsigned int paramsInnerProduct = 1;
39 for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
40 {
41 paramsInnerProduct *= paramsShape[k];
42 }
43
44 unsigned int offset = 0;
45 unsigned int outIndex = 0;
46 for (unsigned int i = 0; i < paramsOuterProduct; ++i)
47 {
48 for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
49 {
50 unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
51 ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
52
53 unsigned int startOffset = (paramsInnerProduct * index) + offset;
54 unsigned int endOffset = startOffset + paramsInnerProduct;
55
56 for (unsigned int k = startOffset; k < endOffset; ++k)
57 {
58 params[k];
59 float outputValue = params.Get();
60 output[outIndex];
61 output.Set(outputValue);
62 ++outIndex;
63 }
64 }
65 offset += paramsShape[axis] * paramsInnerProduct;
66 }
67
68 ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
69 }
70
71 } //namespace armnn