xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace{
16*89c4ff92SAndroid Build Coastguard Worker 
/// Builds a small network that exercises a single Gather layer.
///
/// Topology:
///   Input (binding 0)        --paramsInfo-->  Gather input slot 0
///   Constant (indicesData)   --indicesInfo--> Gather input slot 1
///   Gather ("gather")        --outputInfo-->  Output (binding 0, "output")
///
/// @param paramsInfo   Tensor info assigned to the input layer's output.
/// @param indicesInfo  Tensor info used to build the constant indices tensor
///                     and assigned to the constant layer's output.
/// @param outputInfo   Tensor info assigned to the gather layer's output.
/// @param indicesData  Index values baked into the constant layer.
/// @return The constructed network.
armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
                                       const armnn::TensorInfo& indicesInfo,
                                       const armnn::TensorInfo& outputInfo,
                                       const std::vector<int32_t>& indicesData)
{
    armnn::INetworkPtr network = armnn::INetwork::Create();

    // Layers are added in the order: input, constant, gather, output.
    armnn::IConnectableLayer* const paramsInput = network->AddInputLayer(0);

    const armnn::ConstTensor indicesTensor(indicesInfo, indicesData);
    armnn::IConnectableLayer* const indicesConst = network->AddConstantLayer(indicesTensor);

    // Default-constructed descriptor; callers' expected data assume gathering along axis 0.
    const armnn::GatherDescriptor gatherDesc{};
    armnn::IConnectableLayer* const gather = network->AddGatherLayer(gatherDesc, "gather");
    armnn::IConnectableLayer* const output = network->AddOutputLayer(0, "output");

    // Connect(from, to, info, fromSlot, toSlot) — presumably output/input slot
    // indices; see CommonTestUtils for the helper's exact contract.
    Connect(paramsInput,  gather, paramsInfo,  0, 0);
    Connect(indicesConst, gather, indicesInfo, 0, 1);
    Connect(gather,       output, outputInfo,  0, 0);

    return network;
}
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherEndToEnd(const std::vector<BackendId> & backends)37*89c4ff92SAndroid Build Coastguard Worker void GatherEndToEnd(const std::vector<BackendId>& backends)
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
40*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
41*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationScale(1.0f);
44*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationOffset(0);
45*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetConstant(true);
46*89c4ff92SAndroid Build Coastguard Worker     indicesInfo.SetConstant(true);
47*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationScale(1.0f);
48*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationOffset(0);
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
51*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> paramsData{
52*89c4ff92SAndroid Build Coastguard Worker         1, 2, 3, 4, 5, 6, 7, 8
53*89c4ff92SAndroid Build Coastguard Worker     };
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> indicesData{
56*89c4ff92SAndroid Build Coastguard Worker         7, 6, 5
57*89c4ff92SAndroid Build Coastguard Worker     };
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutput{
60*89c4ff92SAndroid Build Coastguard Worker         8, 7, 6
61*89c4ff92SAndroid Build Coastguard Worker     };
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network
64*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     CHECK(net);
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
69*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherMultiDimEndToEnd(const std::vector<BackendId> & backends)75*89c4ff92SAndroid Build Coastguard Worker void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
78*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
79*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationScale(1.0f);
82*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationOffset(0);
83*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetConstant(true);
84*89c4ff92SAndroid Build Coastguard Worker     indicesInfo.SetConstant(true);
85*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationScale(1.0f);
86*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationOffset(0);
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
89*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> paramsData{
90*89c4ff92SAndroid Build Coastguard Worker          1,  2,  3,
91*89c4ff92SAndroid Build Coastguard Worker          4,  5,  6,
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker          7,  8,  9,
94*89c4ff92SAndroid Build Coastguard Worker         10, 11, 12,
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker         13, 14, 15,
97*89c4ff92SAndroid Build Coastguard Worker         16, 17, 18
98*89c4ff92SAndroid Build Coastguard Worker     };
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> indicesData{
101*89c4ff92SAndroid Build Coastguard Worker         1, 2, 1,
102*89c4ff92SAndroid Build Coastguard Worker         2, 1, 0
103*89c4ff92SAndroid Build Coastguard Worker     };
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutput{
106*89c4ff92SAndroid Build Coastguard Worker          7,  8,  9,
107*89c4ff92SAndroid Build Coastguard Worker         10, 11, 12,
108*89c4ff92SAndroid Build Coastguard Worker         13, 14, 15,
109*89c4ff92SAndroid Build Coastguard Worker         16, 17, 18,
110*89c4ff92SAndroid Build Coastguard Worker          7,  8,  9,
111*89c4ff92SAndroid Build Coastguard Worker         10, 11, 12,
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker         13, 14, 15,
114*89c4ff92SAndroid Build Coastguard Worker         16, 17, 18,
115*89c4ff92SAndroid Build Coastguard Worker          7,  8,  9,
116*89c4ff92SAndroid Build Coastguard Worker         10, 11, 12,
117*89c4ff92SAndroid Build Coastguard Worker          1,  2,  3,
118*89c4ff92SAndroid Build Coastguard Worker          4,  5,  6
119*89c4ff92SAndroid Build Coastguard Worker     };
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network
122*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
125*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
131