xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeChannelShuffle.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_ChannelShuffle")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct ChannelShuffleFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
ChannelShuffleFixtureChannelShuffleFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit ChannelShuffleFixture()
16*89c4ff92SAndroid Build Coastguard Worker     {
17*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
18*89c4ff92SAndroid Build Coastguard Worker         {
19*89c4ff92SAndroid Build Coastguard Worker           layers: [
20*89c4ff92SAndroid Build Coastguard Worker             {
21*89c4ff92SAndroid Build Coastguard Worker               layer_type: "InputLayer",
22*89c4ff92SAndroid Build Coastguard Worker               layer: {
23*89c4ff92SAndroid Build Coastguard Worker                 base: {
24*89c4ff92SAndroid Build Coastguard Worker                   base: {
25*89c4ff92SAndroid Build Coastguard Worker                     layerName: "InputLayer",
26*89c4ff92SAndroid Build Coastguard Worker                     layerType: "Input",
27*89c4ff92SAndroid Build Coastguard Worker                     inputSlots: [
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker                     ],
30*89c4ff92SAndroid Build Coastguard Worker                     outputSlots: [
31*89c4ff92SAndroid Build Coastguard Worker                       {
32*89c4ff92SAndroid Build Coastguard Worker                         tensorInfo: {
33*89c4ff92SAndroid Build Coastguard Worker                           dimensions: [
34*89c4ff92SAndroid Build Coastguard Worker                             3,
35*89c4ff92SAndroid Build Coastguard Worker                             12
36*89c4ff92SAndroid Build Coastguard Worker                           ],
37*89c4ff92SAndroid Build Coastguard Worker                           dataType: "Float32",
38*89c4ff92SAndroid Build Coastguard Worker                           quantizationScale: 0.0,
39*89c4ff92SAndroid Build Coastguard Worker                           dimensionSpecificity: [
40*89c4ff92SAndroid Build Coastguard Worker                             true,
41*89c4ff92SAndroid Build Coastguard Worker                             true
42*89c4ff92SAndroid Build Coastguard Worker                           ]
43*89c4ff92SAndroid Build Coastguard Worker                         }
44*89c4ff92SAndroid Build Coastguard Worker                       }
45*89c4ff92SAndroid Build Coastguard Worker                     ]
46*89c4ff92SAndroid Build Coastguard Worker                   }
47*89c4ff92SAndroid Build Coastguard Worker                 }
48*89c4ff92SAndroid Build Coastguard Worker               }
49*89c4ff92SAndroid Build Coastguard Worker             },
50*89c4ff92SAndroid Build Coastguard Worker             {
51*89c4ff92SAndroid Build Coastguard Worker               layer_type: "ChannelShuffleLayer",
52*89c4ff92SAndroid Build Coastguard Worker               layer: {
53*89c4ff92SAndroid Build Coastguard Worker                 base: {
54*89c4ff92SAndroid Build Coastguard Worker                   index: 1,
55*89c4ff92SAndroid Build Coastguard Worker                   layerName: "channelShuffle",
56*89c4ff92SAndroid Build Coastguard Worker                   layerType: "ChannelShuffle",
57*89c4ff92SAndroid Build Coastguard Worker                   inputSlots: [
58*89c4ff92SAndroid Build Coastguard Worker                     {
59*89c4ff92SAndroid Build Coastguard Worker                       connection: {
60*89c4ff92SAndroid Build Coastguard Worker                         sourceLayerIndex: 0,
61*89c4ff92SAndroid Build Coastguard Worker                         outputSlotIndex: 0
62*89c4ff92SAndroid Build Coastguard Worker                       }
63*89c4ff92SAndroid Build Coastguard Worker                     }
64*89c4ff92SAndroid Build Coastguard Worker                   ],
65*89c4ff92SAndroid Build Coastguard Worker                   outputSlots: [
66*89c4ff92SAndroid Build Coastguard Worker                     {
67*89c4ff92SAndroid Build Coastguard Worker                       tensorInfo: {
68*89c4ff92SAndroid Build Coastguard Worker                         dimensions: [
69*89c4ff92SAndroid Build Coastguard Worker                           3,
70*89c4ff92SAndroid Build Coastguard Worker                           12
71*89c4ff92SAndroid Build Coastguard Worker                         ],
72*89c4ff92SAndroid Build Coastguard Worker                         dataType: "Float32",
73*89c4ff92SAndroid Build Coastguard Worker                         quantizationScale: 0.0,
74*89c4ff92SAndroid Build Coastguard Worker                         dimensionSpecificity: [
75*89c4ff92SAndroid Build Coastguard Worker                           true,
76*89c4ff92SAndroid Build Coastguard Worker                           true
77*89c4ff92SAndroid Build Coastguard Worker                         ]
78*89c4ff92SAndroid Build Coastguard Worker                       }
79*89c4ff92SAndroid Build Coastguard Worker                     }
80*89c4ff92SAndroid Build Coastguard Worker                   ]
81*89c4ff92SAndroid Build Coastguard Worker                 },
82*89c4ff92SAndroid Build Coastguard Worker                 descriptor: {
83*89c4ff92SAndroid Build Coastguard Worker                   axis: 1,
84*89c4ff92SAndroid Build Coastguard Worker                   numGroups: 3
85*89c4ff92SAndroid Build Coastguard Worker                 }
86*89c4ff92SAndroid Build Coastguard Worker               }
87*89c4ff92SAndroid Build Coastguard Worker             },
88*89c4ff92SAndroid Build Coastguard Worker             {
89*89c4ff92SAndroid Build Coastguard Worker               layer_type: "OutputLayer",
90*89c4ff92SAndroid Build Coastguard Worker               layer: {
91*89c4ff92SAndroid Build Coastguard Worker                 base: {
92*89c4ff92SAndroid Build Coastguard Worker                   base: {
93*89c4ff92SAndroid Build Coastguard Worker                     index: 2,
94*89c4ff92SAndroid Build Coastguard Worker                     layerName: "OutputLayer",
95*89c4ff92SAndroid Build Coastguard Worker                     layerType: "Output",
96*89c4ff92SAndroid Build Coastguard Worker                     inputSlots: [
97*89c4ff92SAndroid Build Coastguard Worker                       {
98*89c4ff92SAndroid Build Coastguard Worker                         connection: {
99*89c4ff92SAndroid Build Coastguard Worker                           sourceLayerIndex: 1,
100*89c4ff92SAndroid Build Coastguard Worker                           outputSlotIndex: 0
101*89c4ff92SAndroid Build Coastguard Worker                         }
102*89c4ff92SAndroid Build Coastguard Worker                       }
103*89c4ff92SAndroid Build Coastguard Worker                     ],
104*89c4ff92SAndroid Build Coastguard Worker                     outputSlots: [
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker                     ]
107*89c4ff92SAndroid Build Coastguard Worker                   }
108*89c4ff92SAndroid Build Coastguard Worker                 }
109*89c4ff92SAndroid Build Coastguard Worker               }
110*89c4ff92SAndroid Build Coastguard Worker             }
111*89c4ff92SAndroid Build Coastguard Worker           ],
112*89c4ff92SAndroid Build Coastguard Worker           inputIds: [
113*89c4ff92SAndroid Build Coastguard Worker             0
114*89c4ff92SAndroid Build Coastguard Worker           ],
115*89c4ff92SAndroid Build Coastguard Worker           outputIds: [
116*89c4ff92SAndroid Build Coastguard Worker             0
117*89c4ff92SAndroid Build Coastguard Worker           ],
118*89c4ff92SAndroid Build Coastguard Worker           featureVersions: {
119*89c4ff92SAndroid Build Coastguard Worker             bindingIdsScheme: 1,
120*89c4ff92SAndroid Build Coastguard Worker             weightsLayoutScheme: 1,
121*89c4ff92SAndroid Build Coastguard Worker             constantTensorsAsInputs: 1
122*89c4ff92SAndroid Build Coastguard Worker           }
123*89c4ff92SAndroid Build Coastguard Worker         }
124*89c4ff92SAndroid Build Coastguard Worker     )";
125*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
126*89c4ff92SAndroid Build Coastguard Worker     }
127*89c4ff92SAndroid Build Coastguard Worker };
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker struct SimpleChannelShuffleFixtureFloat32 : ChannelShuffleFixture
130*89c4ff92SAndroid Build Coastguard Worker {
SimpleChannelShuffleFixtureFloat32SimpleChannelShuffleFixtureFloat32131*89c4ff92SAndroid Build Coastguard Worker     SimpleChannelShuffleFixtureFloat32() : ChannelShuffleFixture(){}
132*89c4ff92SAndroid Build Coastguard Worker };
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleChannelShuffleFixtureFloat32, "ChannelShuffleFloat32")
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(0,
137*89c4ff92SAndroid Build Coastguard Worker                                          {{"InputLayer",
138*89c4ff92SAndroid Build Coastguard Worker                                            {  0, 1, 2, 3,        4, 5, 6, 7,       8, 9, 10, 11,
139*89c4ff92SAndroid Build Coastguard Worker                                             12, 13, 14, 15,   16, 17, 18, 19,   20, 21, 22, 23,
140*89c4ff92SAndroid Build Coastguard Worker                                             24, 25, 26, 27,   28, 29, 30, 31,   32, 33, 34, 35}}},
141*89c4ff92SAndroid Build Coastguard Worker                                          {{"OutputLayer",
142*89c4ff92SAndroid Build Coastguard Worker                                            { 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11,
143*89c4ff92SAndroid Build Coastguard Worker                                             12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23,
144*89c4ff92SAndroid Build Coastguard Worker                                             24, 28, 32, 25, 29, 33, 26, 30, 34, 27, 31, 35 }}});
145*89c4ff92SAndroid Build Coastguard Worker }
146*89c4ff92SAndroid Build Coastguard Worker }