xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Gather.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "Gather.hpp"

#include <armnn/Exceptions.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <cstdint>
#include <string>
10 
11 namespace armnn
12 {
13 
Gather(const TensorInfo & paramsInfo,const TensorInfo & indicesInfo,const TensorInfo & outputInfo,Decoder<float> & params,const int32_t * indices,Encoder<float> & output,const int32_t axis_int)14 void Gather(const TensorInfo& paramsInfo,
15             const TensorInfo& indicesInfo,
16             const TensorInfo& outputInfo,
17             Decoder<float>& params,
18             const int32_t* indices,
19             Encoder<float>& output,
20             const int32_t axis_int)
21 {
22     IgnoreUnused(outputInfo);
23 
24     const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
25     ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
26     const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
27                                              : static_cast<unsigned int>(axis_int);
28 
29     const TensorShape& paramsShape = paramsInfo.GetShape();
30 
31     // Product of all dimensions to the left side of the axis
32     unsigned int paramsOuterProduct = 1;
33     for (unsigned int i = 0; i < axis; ++i)
34     {
35         paramsOuterProduct *= paramsShape[i];
36     }
37     // Product of all dimensions to the right side of the axis
38     unsigned int paramsInnerProduct = 1;
39     for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
40     {
41         paramsInnerProduct *= paramsShape[k];
42     }
43 
44     unsigned int offset = 0;
45     unsigned int outIndex = 0;
46     for (unsigned int i = 0; i < paramsOuterProduct; ++i)
47     {
48         for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
49         {
50             unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
51             ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
52 
53             unsigned int startOffset = (paramsInnerProduct * index) + offset;
54             unsigned int endOffset = startOffset + paramsInnerProduct;
55 
56             for (unsigned int k = startOffset; k < endOffset; ++k)
57             {
58                 params[k];
59                 float outputValue = params.Get();
60                 output[outIndex];
61                 output.Set(outputValue);
62                 ++outIndex;
63             }
64         }
65         offset += paramsShape[axis] * paramsInnerProduct;
66     }
67 
68     ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
69 }
70 
71 } //namespace armnn