// Source: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeRsqrt.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Rsqrt")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct RsqrtFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
RsqrtFixtureRsqrtFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit RsqrtFixture(const std::string & inputShape,
16*89c4ff92SAndroid Build Coastguard Worker                           const std::string & outputShape,
17*89c4ff92SAndroid Build Coastguard Worker                           const std::string & dataType)
18*89c4ff92SAndroid Build Coastguard Worker     {
19*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
20*89c4ff92SAndroid Build Coastguard Worker         {
21*89c4ff92SAndroid Build Coastguard Worker                 inputIds: [0],
22*89c4ff92SAndroid Build Coastguard Worker                 outputIds: [2],
23*89c4ff92SAndroid Build Coastguard Worker                 layers: [
24*89c4ff92SAndroid Build Coastguard Worker                 {
25*89c4ff92SAndroid Build Coastguard Worker                     layer_type: "InputLayer",
26*89c4ff92SAndroid Build Coastguard Worker                     layer: {
27*89c4ff92SAndroid Build Coastguard Worker                           base: {
28*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId: 0,
29*89c4ff92SAndroid Build Coastguard Worker                                 base: {
30*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
31*89c4ff92SAndroid Build Coastguard Worker                                     layerName: "InputLayer",
32*89c4ff92SAndroid Build Coastguard Worker                                     layerType: "Input",
33*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots: [{
34*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
35*89c4ff92SAndroid Build Coastguard Worker                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
36*89c4ff92SAndroid Build Coastguard Worker                                     }],
37*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots: [ {
38*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
39*89c4ff92SAndroid Build Coastguard Worker                                         tensorInfo: {
40*89c4ff92SAndroid Build Coastguard Worker                                             dimensions: )" + inputShape + R"(,
41*89c4ff92SAndroid Build Coastguard Worker                                             dataType: )" + dataType + R"(
42*89c4ff92SAndroid Build Coastguard Worker                                         },
43*89c4ff92SAndroid Build Coastguard Worker                                     }],
44*89c4ff92SAndroid Build Coastguard Worker                                  },}},
45*89c4ff92SAndroid Build Coastguard Worker                 },
46*89c4ff92SAndroid Build Coastguard Worker                 {
47*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "RsqrtLayer",
48*89c4ff92SAndroid Build Coastguard Worker                 layer : {
49*89c4ff92SAndroid Build Coastguard Worker                         base: {
50*89c4ff92SAndroid Build Coastguard Worker                              index:1,
51*89c4ff92SAndroid Build Coastguard Worker                              layerName: "RsqrtLayer",
52*89c4ff92SAndroid Build Coastguard Worker                              layerType: "Rsqrt",
53*89c4ff92SAndroid Build Coastguard Worker                              inputSlots: [
54*89c4ff92SAndroid Build Coastguard Worker                                             {
55*89c4ff92SAndroid Build Coastguard Worker                                              index: 0,
56*89c4ff92SAndroid Build Coastguard Worker                                              connection: {sourceLayerIndex:0, outputSlotIndex:0 },
57*89c4ff92SAndroid Build Coastguard Worker                                             }
58*89c4ff92SAndroid Build Coastguard Worker                              ],
59*89c4ff92SAndroid Build Coastguard Worker                              outputSlots: [ {
60*89c4ff92SAndroid Build Coastguard Worker                                  index: 0,
61*89c4ff92SAndroid Build Coastguard Worker                                  tensorInfo: {
62*89c4ff92SAndroid Build Coastguard Worker                                      dimensions: )" + outputShape + R"(,
63*89c4ff92SAndroid Build Coastguard Worker                                      dataType: )" + dataType + R"(
64*89c4ff92SAndroid Build Coastguard Worker                                  },
65*89c4ff92SAndroid Build Coastguard Worker                              }],
66*89c4ff92SAndroid Build Coastguard Worker                             }},
67*89c4ff92SAndroid Build Coastguard Worker                 },
68*89c4ff92SAndroid Build Coastguard Worker                 {
69*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "OutputLayer",
70*89c4ff92SAndroid Build Coastguard Worker                 layer: {
71*89c4ff92SAndroid Build Coastguard Worker                         base:{
72*89c4ff92SAndroid Build Coastguard Worker                               layerBindingId: 0,
73*89c4ff92SAndroid Build Coastguard Worker                               base: {
74*89c4ff92SAndroid Build Coastguard Worker                                     index: 2,
75*89c4ff92SAndroid Build Coastguard Worker                                     layerName: "OutputLayer",
76*89c4ff92SAndroid Build Coastguard Worker                                     layerType: "Output",
77*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots: [{
78*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
79*89c4ff92SAndroid Build Coastguard Worker                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
80*89c4ff92SAndroid Build Coastguard Worker                                     }],
81*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots: [ {
82*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
83*89c4ff92SAndroid Build Coastguard Worker                                         tensorInfo: {
84*89c4ff92SAndroid Build Coastguard Worker                                             dimensions: )" + outputShape + R"(,
85*89c4ff92SAndroid Build Coastguard Worker                                             dataType: )" + dataType + R"(
86*89c4ff92SAndroid Build Coastguard Worker                                         },
87*89c4ff92SAndroid Build Coastguard Worker                                 }],
88*89c4ff92SAndroid Build Coastguard Worker                             }}},
89*89c4ff92SAndroid Build Coastguard Worker                 }]
90*89c4ff92SAndroid Build Coastguard Worker          }
91*89c4ff92SAndroid Build Coastguard Worker         )";
92*89c4ff92SAndroid Build Coastguard Worker         Setup();
93*89c4ff92SAndroid Build Coastguard Worker     }
94*89c4ff92SAndroid Build Coastguard Worker };
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker struct Rsqrt2dFixture : RsqrtFixture
98*89c4ff92SAndroid Build Coastguard Worker {
Rsqrt2dFixtureRsqrt2dFixture99*89c4ff92SAndroid Build Coastguard Worker     Rsqrt2dFixture() : RsqrtFixture("[ 2, 2 ]",
100*89c4ff92SAndroid Build Coastguard Worker                                     "[ 2, 2 ]",
101*89c4ff92SAndroid Build Coastguard Worker                                     "Float32") {}
102*89c4ff92SAndroid Build Coastguard Worker };
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Rsqrt2dFixture, "Rsqrt2d")
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker   RunTest<2, armnn::DataType::Float32>(
107*89c4ff92SAndroid Build Coastguard Worker       0,
108*89c4ff92SAndroid Build Coastguard Worker       {{"InputLayer", { 1.0f,  4.0f,
109*89c4ff92SAndroid Build Coastguard Worker                         16.0f, 25.0f }}},
110*89c4ff92SAndroid Build Coastguard Worker       {{"OutputLayer",{ 1.0f,  0.5f,
111*89c4ff92SAndroid Build Coastguard Worker                         0.25f, 0.2f }}});
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker }
116